Unverified commit 1cf54e86, authored by Chao Liu, committed by GitHub

Shuffle in thread (#13)

* adding in-thread shuffle

* update softmax example

* refactor grid gemm

* refactor gemm: layouts

* bug fix

* clean

* clean
parent 2837e6b3
......@@ -51,6 +51,10 @@ int main(int argc, char* argv[])
using AccDataType = float;
using CDataType = ck::half_t;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
ck::index_t M = 3328;
ck::index_t N = 4096;
ck::index_t K = 4096;
......@@ -62,14 +66,24 @@ int main(int argc, char* argv[])
K = std::stoi(argv[3]);
}
    const ck::index_t Lda = std::is_same_v<ALayout, ck::tensor_layout::gemm::RowMajor> ? K : M;
    const ck::index_t Ldb = std::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor> ? K : N;
    const ck::index_t Ldc = std::is_same_v<CLayout, ck::tensor_layout::gemm::RowMajor> ? N : M;

    const auto a_lengths = std::array<ck::index_t, 2>{M, K};
    const auto a_strides = std::is_same_v<ALayout, ck::tensor_layout::gemm::RowMajor>
                               ? std::array<ck::index_t, 2>{Lda, 1}
                               : std::array<ck::index_t, 2>{1, Lda};

    const auto b_lengths = std::array<ck::index_t, 2>{N, K};
    const auto b_strides = std::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
                               ? std::array<ck::index_t, 2>{Ldb, 1}
                               : std::array<ck::index_t, 2>{1, Ldb};

    const auto c_lengths = std::array<ck::index_t, 2>{M, N};
    const auto c_strides = std::is_same_v<CLayout, ck::tensor_layout::gemm::RowMajor>
                               ? std::array<ck::index_t, 2>{Ldc, 1}
                               : std::array<ck::index_t, 2>{1, Ldc};
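    // Worked example (illustrative, not part of the diff): with the default RCR layouts
    // above and M = 3328, N = K = 4096, this yields Lda = Ldb = Ldc = 4096, and all three
    // stride pairs become {4096, 1}.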
// host verify
Tensor<ADataType> a_host(a_lengths, a_strides);
......@@ -90,11 +104,17 @@ int main(int argc, char* argv[])
a_buf.ToDevice(a_host.mData.data());
b_buf.ToDevice(b_host.mData.data());
// Alignment
constexpr ck::index_t kAAlignment = 32;
constexpr ck::index_t kBAlignment = 32;
constexpr ck::index_t kCAlignment = 32;
    constexpr ck::index_t kGemmMPerBlock = 256;
    constexpr ck::index_t kGemmNPerBlock = 128;
    constexpr ck::index_t kGemmKPerBlock = 32;

    constexpr ck::index_t kBlockSize = 256;
ck::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
std::cout << "grid size " << kGridSize << std::endl;
......@@ -107,12 +127,15 @@ int main(int argc, char* argv[])
BDataType,
AccDataType,
CDataType,
                          ALayout,
                          BLayout,
                          CLayout,
AElementFunction,
BElementFunction,
CElementFunction,
kAAlignment,
kBAlignment,
kCAlignment,
kBlockSize,
kGemmMPerBlock,
kGemmNPerBlock,
......@@ -130,9 +153,9 @@ int main(int argc, char* argv[])
M,
N,
K,
        Lda,
        Ldb,
        Ldc,
AElementFunction{},
BElementFunction{},
CElementFunction{});
......
......@@ -15,9 +15,9 @@
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/grid/grid_gemm.hpp"
#include "ck/tile_program/grid/grid_gemm_policy.hpp"
#include "ck/tile_program/grid/grid_gemm_problem.hpp"
#include "ck/tile_program/grid/grid_gemm_v1.hpp"
#include "ck/tile_program/grid/grid_gemm_v1_default_policy.hpp"
// C = A * B
template <typename ADataType,
......@@ -30,17 +30,16 @@ template <typename ADataType,
typename AElementFunction,
typename BElementFunction,
typename CElementFunction,
          ck::index_t kAAlignment,
          ck::index_t kBAlignment,
          ck::index_t kCAlignment,
          ck::index_t kBlockSize_,
          ck::index_t kMPerBlock_,
          ck::index_t kNPerBlock_,
          ck::index_t kKPerBlock_>
struct Gemm
{
    using GridGemmProblem = ck::tile_program::grid::GridGemmProblem<ADataType,
BDataType,
AccDataType,
CDataType,
......@@ -48,26 +47,60 @@ struct Gemm
BElementFunction,
CElementFunction>;
    struct GridGemmPolicy
{
static constexpr ck::index_t kBlockSize = kBlockSize_;
static constexpr ck::index_t kMPerBlock = kMPerBlock_;
static constexpr ck::index_t kNPerBlock = kNPerBlock_;
static constexpr ck::index_t kKPerBlock = kKPerBlock_;
template <typename Problem>
__host__ __device__ static constexpr auto MakeBlock2TileMap(ck::index_t NumTilesM,
ck::index_t NumTilesN)
{
using namespace ck;
const auto unmerge = make_merge_transform(make_tuple(NumTilesN, NumTilesM));
return [unmerge](index_t block_id) {
MultiIndex<2> unmerged;
unmerge.CalculateLowerIndex(unmerged, make_multi_index(block_id));
return make_multi_index(unmerged.At<1>(), unmerged.At<0>());
};
}
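        // Worked example (explanatory, not part of the diff): unmerging block_id over
        // (NumTilesN, NumTilesM) and swapping the components makes the M tile index vary
        // fastest:
        //   m_tile = block_id % NumTilesM, n_tile = block_id / NumTilesM
        // e.g. NumTilesM = 13, NumTilesN = 32: block 0 -> (0, 0), block 1 -> (1, 0),
        // block 13 -> (0, 1).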
template <typename Problem>
__host__ __device__ static constexpr auto GetBlockGemmPipeline()
{
using namespace ck;
using namespace ck::tile_program;
using namespace ck::tile_program::block;
            using BlockGemmPipelineProblem_ =
                BlockGemmPipelineProblem<ADataType,
                                         BDataType,
                                         AccDataType,
                                         kBlockSize,
                                         TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;

            return BlockGemmPipelineAGmemBGmemCRegV2<
                BlockGemmPipelineProblem_,
                BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>{};
}
};
using GridGemm = ck::GridGemmV1<GridGemmProblem, GridGemmPolicy>;
__device__ void operator()(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const ck::index_t Lda,
const ck::index_t Ldb,
const ck::index_t Ldc,
const AElementFunction& a_element_func,
const BElementFunction& b_element_func,
const CElementFunction& c_element_func) const
......@@ -76,17 +109,63 @@ struct Gemm
using namespace ck::tile_program;
using namespace ck::tile_program::block;
        const auto a_dram = [&] {
            if constexpr(is_same_v<ALayout, ck::tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<AddressSpaceEnum::Global>(
                    p_a, make_tuple(M, K), make_tuple(Lda, 1), Number<kAAlignment>{}, Number<1>{});
            }
            else
            {
                const auto a_k_m_desc = make_naive_tensor_view<AddressSpaceEnum::Global>(
                    p_a, make_tuple(K, M), make_tuple(Lda, 1), Number<kAAlignment>{}, Number<1>{});

                return transform_tensor_view(
                    a_k_m_desc,
                    make_tuple(make_pass_through_transform(M), make_pass_through_transform(K)),
                    make_tuple(Sequence<1>{}, Sequence<0>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
        }();
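        // Note (explanatory, not part of the diff): in the non-native branch the raw view
        // is created with the physical (K, M) shape and re-indexed to a logical (M, K)
        // view via transform_tensor_view, so the tile windows downstream always see an
        // M x K tensor regardless of ALayout; b_dram and c_dram follow the same pattern.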
const auto b_dram = [&] {
if constexpr(is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_view<AddressSpaceEnum::Global>(
p_b, make_tuple(N, K), make_tuple(Ldb, 1), Number<kBAlignment>{}, Number<1>{});
}
else
{
const auto b_k_n_desc = make_naive_tensor_view<AddressSpaceEnum::Global>(
p_b, make_tuple(K, N), make_tuple(Ldb, 1), Number<kBAlignment>{}, Number<1>{});
return transform_tensor_view(
b_k_n_desc,
make_tuple(make_pass_through_transform(N), make_pass_through_transform(K)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto c_dram = [&] {
if constexpr(is_same_v<CLayout, ck::tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<AddressSpaceEnum::Global>(
p_c, make_tuple(M, N), make_tuple(Ldc, 1), Number<kCAlignment>{}, Number<1>{});
}
else
{
const auto c_n_m_desc = make_naive_tensor_view<AddressSpaceEnum::Global>(
p_c, make_tuple(N, M), make_tuple(Ldc, 1), Number<kCAlignment>{}, Number<1>{});
return transform_tensor_view(
c_n_m_desc,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
        GridGemm{}(a_dram, b_dram, c_dram, a_element_func, b_element_func, c_element_func);
}
};
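// Illustrative sketch (not part of the diff; values taken from the example above): the
// functor bakes data types, layouts, alignments and tile sizes in at compile time, e.g.
//
//   using GemmKernel = Gemm<ck::half_t, ck::half_t, float, ck::half_t,
//                           ALayout, BLayout, CLayout,
//                           AElementFunction, BElementFunction, CElementFunction,
//                           kAAlignment, kBAlignment, kCAlignment,
//                           kBlockSize, kGemmMPerBlock, kGemmNPerBlock, kGemmKPerBlock>;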
......@@ -22,7 +22,7 @@ int main(int argc, char* argv[])
using AccDataType = float;
using BDataType = ck::half_t;
    ck::index_t M = 13312;
ck::index_t N = 4096;
if(argc == 3)
......
......@@ -23,37 +23,6 @@ template <typename ADataType,
ck::index_t kNPerBlock>
struct Softmax
{
__device__ static constexpr auto MakeABlockTileDistribution()
{
using namespace ck;
......@@ -68,10 +37,9 @@ struct Softmax
Sequence<1, 2, 1, 1>,
Sequence<0, 0, 2, 4>>{});
}
    __device__ void
    MultiPassSoftmax(const ADataType* p_a, BDataType* p_b, ck::index_t M, ck::index_t N) const
{
using namespace ck;
using namespace ck::tile_program;
......@@ -80,144 +48,226 @@ struct Softmax
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
        // A DRAM tensor view
        const auto a_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_a, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});

        const auto iM = get_block_id() * kMPerBlock;

        // A DRAM window
        auto a_dram_window =
            make_tile_window(a_dram,
                             make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}),
                             {iM, 0},
                             MakeABlockTileDistribution());
        const auto f_max = [](auto v0, auto v1) { return max(v0, v1); };

        // m = rowmax(A)
        auto m = decltype(block_tile_reduce<AccDataType>(
            load_tile(a_dram_window), Sequence<1>{}, f_max, ADataType{})){};

        tile_elementwise_inout(
            [&](auto& e) { e = type_convert<AccDataType>(NumericLimits<ADataType>::Lowest()); }, m);
index_t iN = 0;
do
{
            // load A tile from DRAM
            const auto a = load_tile(a_dram_window);

            // m = rowmax(A)
            block_tile_reduce(m, a, Sequence<1>{}, f_max);

            move_tile_window(a_dram_window, {0, kNPerBlock});
iN += kNPerBlock;
} while(iN < N);
        // cross lane reduce: max
        block_tile_reduce_sync(m, f_max);

        // reset window location
        iN = 0;
        move_tile_window(a_dram_window, {0, -N});

        // l = rowsum(exp(A - m))
        auto l = make_static_distributed_tensor<AccDataType>(m.GetTileDistribution());
        tile_elementwise_inout([&](auto& e) { e = 0; }, l);
do
{
            // load A tile from DRAM
            const auto a = load_tile(a_dram_window);

            constexpr auto a_spans = decltype(a)::GetDistributedSpans();

            sweep_tile_span(a_spans[I0], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);

                sweep_tile_span(a_spans[I1], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);

                    // l = rowsum(exp(A - m))
                    l(i_idx) += math::exp(a[i_j_idx] - m[i_idx]);
                });
            });

            move_tile_window(a_dram_window, {0, kNPerBlock});
iN += kNPerBlock;
} while(iN < N);
        // cross lane reduce: sum
        block_tile_reduce_sync(l, [](auto e0, auto e1) { return e0 + e1; });

        // reset window location
        iN = 0;
        move_tile_window(a_dram_window, {0, -N});

        // B DRAM tensor view
        const auto b_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_b, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});

        // B DRAM window
        auto b_dram_window = make_tile_window(
            b_dram, make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}), {iM, 0});
        // B = exp(A - m) / l
        do
        {
            // load A tile from DRAM
            const auto a = load_tile(a_dram_window);

            constexpr auto a_spans = decltype(a)::GetDistributedSpans();

            auto b = make_static_distributed_tensor<BDataType>(a.GetTileDistribution());

            sweep_tile_span(a_spans[I0], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);

                sweep_tile_span(a_spans[I1], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);

                    // B = exp(A - m) / l
                    b(i_j_idx) =
                        type_convert<BDataType>(math::exp(a[i_j_idx] - m[i_idx]) / l[i_idx]);
                });
            });

            // store B tile
            store_tile(b_dram_window, b);

            move_tile_window(a_dram_window, {0, kNPerBlock});
            move_tile_window(b_dram_window, {0, kNPerBlock});
iN += kNPerBlock;
} while(iN < N);
}
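    // Note (explanatory, not part of the diff): both paths compute the numerically stable
    // softmax B[i, j] = exp(A[i, j] - m[i]) / l[i], with m[i] = rowmax(A) and
    // l[i] = rowsum(exp(A - m)). MultiPassSoftmax above streams each row in
    // kNPerBlock-wide tiles and reads A three times; SinglePassSoftmax below assumes the
    // whole row fits in one tile (N <= kNPerBlock) and reads A once.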
__device__ void
SinglePassSoftmax(const ADataType* p_a, BDataType* p_b, ck::index_t M, ck::index_t N) const
{
using namespace ck;
using namespace ck::tile_program;
using namespace ck::tile_program::block;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// A DRAM tensor view
const auto a_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
p_a, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});
const auto iM = get_block_id() * kMPerBlock;
// A DRAM window
auto a_dram_window =
make_tile_window(a_dram,
make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}),
{iM, 0},
MakeABlockTileDistribution());
// f_max
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
// m = rowmax(A)
auto m = decltype(block_tile_reduce<AccDataType>(
load_tile(a_dram_window), Sequence<1>{}, f_max, ADataType{})){};
tile_elementwise_inout(
[&](auto& e) { e = type_convert<AccDataType>(NumericLimits<ADataType>::Lowest()); }, m);
// l = rowsum(exp(A - m))
auto l = make_static_distributed_tensor<AccDataType>(m.GetTileDistribution());
tile_elementwise_inout([&](auto& e) { e = 0; }, l);
// load A tile from DRAM
const auto a = load_tile(a_dram_window);
constexpr auto a_spans = decltype(a)::GetDistributedSpans();
// m = rowmax(A)
block_tile_reduce(m, a, Sequence<1>{}, f_max);
// cross lane reduce: max
block_tile_reduce_sync(m, f_max);
// l = rowsum(exp(A - m))
        sweep_tile_span(a_spans[I0], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);

            sweep_tile_span(a_spans[I1], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);

                l(i_idx) += math::exp(a[i_j_idx] - m[i_idx]);
            });
        });

        // cross lane reduce: sum
        block_tile_reduce_sync(l, [](auto e0, auto e1) { return e0 + e1; });

        auto b = make_static_distributed_tensor<BDataType>(a.GetTileDistribution());

        // B = exp(A - m) / l
        sweep_tile_span(a_spans[I0], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);

            sweep_tile_span(a_spans[I1], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);

                b(i_j_idx) = type_convert<BDataType>(math::exp(a[i_j_idx] - m[i_idx]) / l[i_idx]);
            });
        });

        // B DRAM tensor view
        const auto b_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_b, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});

        // B DRAM window
        auto b_dram_window = make_tile_window(
            b_dram, make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}), {iM, 0});

        // store B tile
        store_tile(b_dram_window, b);
    }
__device__ void
operator()(const ADataType* p_a, BDataType* p_b, ck::index_t M, ck::index_t N) const
{
if(N > kNPerBlock)
{
MultiPassSoftmax(p_a, p_b, M, N);
}
else
{
SinglePassSoftmax(p_a, p_b, M, N);
}
}
};
......@@ -13,7 +13,6 @@
#include "ck/utility/tuple.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
namespace ck {
namespace detail {
......@@ -48,9 +47,6 @@ __host__ __device__ static auto make_desc_to_block2tile_map_adaptor(Descriptor&&
}
} // namespace detail
struct Block2TileMapNFast
{
__host__ __device__ static constexpr auto MakeBlock2TileMap(index_t NumTilesM,
......@@ -63,8 +59,7 @@ struct Block2TileMapNFast
struct Block2TileMapMFast
{
    __host__ __device__ constexpr auto operator()(index_t NumTilesM, index_t NumTilesN) const
{
const auto unmerge = make_merge_transform(make_tuple(NumTilesN, NumTilesM));
......@@ -129,43 +124,4 @@ struct Block2TileMapMAdapt
using DefaultBlock2TileMap = Block2TileMapMFast;
} // namespace ck
......@@ -4,11 +4,9 @@
#pragma once
namespace ck {
template <typename Problem, typename Policy>
struct GridGemmV1
{
using ADataType = typename Problem::ADataType;
using BDataType = typename Problem::BDataType;
......@@ -21,8 +19,6 @@ struct GridGemm
static constexpr auto kNPerBlock = Policy::kNPerBlock;
static constexpr auto kKPerBlock = Policy::kKPerBlock;
template <typename AGridTensorView, typename BGridTensorView, typename CGridTensorView>
__device__ void operator()(const AGridTensorView& a_grid,
const BGridTensorView& b_grid,
......@@ -35,9 +31,9 @@ struct GridGemm
using namespace ck::tile_program;
using namespace ck::tile_program::block;
        const auto M = a_grid.GetTensorDescriptor().GetLength(Number<0>{});
        const auto N = c_grid.GetTensorDescriptor().GetLength(Number<1>{});
        const auto K = a_grid.GetTensorDescriptor().GetLength(Number<1>{});
// divide problem
const auto id_block = get_block_id();
......@@ -45,7 +41,7 @@ struct GridGemm
const auto num_tile_m = M / kMPerBlock;
const auto num_tile_n = N / kNPerBlock;
        const auto block2tile = Policy::template MakeBlock2TileMap<Problem>(num_tile_m, num_tile_n);
const auto id_tile = block2tile(id_block);
......@@ -61,7 +57,7 @@ struct GridGemm
b_grid, make_tuple(Number<kNPerBlock>{}, Number<kKPerBlock>{}), {iN, 0});
// Block GEMM pipeline
        constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()];
......@@ -85,6 +81,4 @@ struct GridGemm
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
namespace ck {
// Default policy for GridGemmV1.
// A default policy class should not be templated; put any templates on its member
// functions instead.
struct GridGemmV1DefaultPolicy
{
static constexpr index_t kBlockSize = 256;
static constexpr index_t kMPerBlock = 128;
static constexpr index_t kNPerBlock = 128;
static constexpr index_t kKPerBlock = 32;
template <typename Problem>
__host__ __device__ static constexpr auto MakeBlock2TileMap(index_t NumTilesM,
index_t NumTilesN)
{
const auto unmerge = make_merge_transform(make_tuple(NumTilesN, NumTilesM));
return [unmerge](index_t block_id) {
MultiIndex<2> unmerged;
unmerge.CalculateLowerIndex(unmerged, make_multi_index(block_id));
return make_multi_index(unmerged.At<1>(), unmerged.At<0>());
};
}
template <typename Problem>
__host__ __device__ static constexpr auto GetBlockGemmPipeline()
{
using namespace ck::tile_program;
using namespace ck::tile_program::block;
using BlockGemmPipelineProblem_ =
BlockGemmPipelineProblem<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::AccDataType,
kBlockSize,
TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;
return BlockGemmPipelineAGmemBGmemCRegV2<BlockGemmPipelineProblem_,
BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>{};
}
};
} // namespace ck
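// Illustrative usage sketch (not part of the diff): a grid GEMM is composed by pairing a
// problem with a policy, e.g.
//
//   using MyGridGemm = ck::GridGemmV1<MyGridGemmProblem, ck::GridGemmV1DefaultPolicy>;
//
// where MyGridGemmProblem is a hypothetical ck::tile_program::grid::GridGemmProblem
// instantiation. A custom policy only needs the four tile-size constants plus
// MakeBlock2TileMap and GetBlockGemmPipeline, as the Gemm example above demonstrates.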
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/shuffle_distributed_tensor_impl_in_thread.hpp"
namespace ck {
namespace tile_program {
template <typename OutTensor, typename InTensor>
__device__ void shuffle_distributed_tensor(OutTensor& out, const InTensor& in)
{
using InDataType = typename InTensor::DataType;
using OutDataType = typename OutTensor::DataType;
using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode;
using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;
// type convert
const auto in_tmp = tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
// shuffle
if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
InDstrEncode::NDimY == OutDstrEncode::NDimY)
{
detail::shuffle_distributed_tensor_impl_in_thread(out, in_tmp);
}
    else
    {
        // not implemented: shuffling between distributions whose R/H/P encodings differ
        // is not supported yet, so this branch is currently a silent no-op
    }
}
} // namespace tile_program
} // namespace ck
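// Illustrative usage sketch (hypothetical tensors, not from the diff): convert and
// shuffle an fp32 accumulator tile into an fp16 tile whose distribution differs only in
// the order of its Y dimensions:
//
//   auto c_shuffled = make_static_distributed_tensor<ck::half_t>(MakeShuffledCDistribution());
//   ck::tile_program::shuffle_distributed_tensor(c_shuffled, c_acc);
//
// MakeShuffledCDistribution() is a placeholder; its encoding must share rs_lengths_,
// hs_lengthss_ and the P mappings with c_acc's distribution, or the call silently does
// nothing.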
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/static_distributed_tensor.hpp"
namespace ck {
namespace tile_program {
namespace detail {
template <typename OutTensor, typename InTensor>
__device__ void shuffle_distributed_tensor_impl_in_thread(OutTensor& out_tensor,
const InTensor& in_tensor)
{
constexpr auto I0 = Number<0>{};
using DataType = typename InTensor::DataType;
constexpr auto y_in_desc = InTensor::GetTileDistribution().GetYs2DDescriptor();
constexpr auto y_out_desc = OutTensor::GetTileDistribution().GetYs2DDescriptor();
// y_dim_out_to_in
constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
using DstrEncode = typename decltype(dstr_tensor.GetTileDistribution())::DstrEncode;
Map<Array<index_t, 2>, index_t> rh_major_minor_to_y_;
static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
rh_major_minor_to_y_({rh_major, rh_minor}) = i;
});
return rh_major_minor_to_y_;
};
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
constexpr auto y_dim_out_to_in = [&] {
Map<index_t, index_t> y_dim_out_to_in_;
for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
{
y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
}
return y_dim_out_to_in_;
}();
//
constexpr index_t NDimY = InTensor::GetTileDistribution().GetNumOfDimensionY();
constexpr auto y_lengths = to_sequence(y_in_desc.GetLengths());
// input and output vector dim in the order of input Y dims
constexpr index_t y_dim_vec_in = NDimY - 1;
constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
// vector lengths
constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
// # of vectors
constexpr index_t num_vec_in = vec_length_out;
constexpr index_t num_vec_out = vec_length_in;
using InVec = vector_type<DataType, vec_length_in>;
using OutVec = vector_type<DataType, vec_length_out>;
using InVecType = typename InVec::type;
using OutVecType = typename OutVec::type;
// SFC
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using SFC_Y = SpaceFillingCurve<decltype(y_lengths),
typename arithmetic_sequence_gen<0, NDimY, 1>::type,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Y::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// in/out vectors to be transposed
StaticallyIndexedArray<InVec, num_vec_in> in_vectors;
StaticallyIndexedArray<OutVec, num_vec_out> out_vectors;
#if 0
print(y_dim_out_to_in);
printf("\n");
printf("y_dim_vec_in %d\n", y_dim_vec_in);
printf("y_dim_vec_out %d\n", y_dim_vec_out);
printf("num_vec_in %d\n", num_vec_in);
printf("num_vec_out %d\n", num_vec_out);
#endif
// loop over SFC and do transpose
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...] in the order of input tensor
constexpr auto idx_y_start = SFC_Y::GetIndex(iAccess);
// get input vectors
static_for<0, num_vec_in, 1>{}([&](auto i) {
constexpr auto idx_y_in = generate_array(
[&](auto ii) {
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
},
Number<NDimY>{});
constexpr index_t in_offset = y_in_desc.CalculateOffset(idx_y_in);
in_vectors(i).template AsType<InVecType>()(I0) =
in_tensor.GetThreadBuffer().template GetAsType<InVecType>(Number<in_offset>{});
});
// transpose
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
Number<NDimY>{});
constexpr auto idx_y_out =
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
constexpr index_t out_offset = y_out_desc.CalculateOffset(idx_y_out);
out_tensor.GetThreadBuffer().template SetAsType<OutVecType>(
Number<out_offset>{}, out_vectors[i].template AsType<OutVecType>()[I0]);
});
});
}
} // namespace detail
} // namespace tile_program
} // namespace ck
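// Worked example (illustrative numbers, not part of the diff): if the input tile's last Y
// dimension has length 4 (vec_length_in = 4) and it maps to an output Y dimension of
// length 8 (vec_length_out = 8), each space-filling-curve step loads num_vec_in = 8
// vectors of 4 elements, transposes them in registers via transpose_vectors, and writes
// num_vec_out = 4 vectors of 8 elements into the output thread buffer.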
......@@ -349,59 +349,6 @@ make_reduce_tile_distribution_encoding(InDstr, Sequence<InReduceDimXs...> reduce
remove_cvref_t<decltype(ps_to_rhss_minor)>,
remove_cvref_t<decltype(ys_to_rhs_major)>,
remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
}
} // namespace detail
......
......@@ -401,7 +401,7 @@ __host__ __device__ constexpr auto
rh_major_minor_to_hidden_ids);
}
// FIXME: this is nasty. Move it inside TileDistributionEncoding::Detail
template <typename RhMajorMinor2AdaptorHiddenIdss> // Tuple<Sequence<...>, ...>
struct TileDistributionDetail
{
......
......@@ -5,8 +5,8 @@
#include <initializer_list>
#include "functional2.hpp"
#include "sequence.hpp"
#include "ck/utility/functional2.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
......@@ -113,6 +113,31 @@ struct Array<TData, 0>
__host__ __device__ void Print() const { printf("Array{size: 0, data: []}"); }
};
template <typename TData, index_t NSize>
__host__ __device__ constexpr bool operator==(const Array<TData, NSize>& a,
const Array<TData, NSize>& b)
{
bool same = true;
for(index_t i = 0; i < NSize; ++i)
{
if(a[i] != b[i])
{
same = false;
break;
}
}
return same;
}
template <typename TData, index_t NSize>
__host__ __device__ constexpr bool operator!=(const Array<TData, NSize>& a,
const Array<TData, NSize>& b)
{
return !(a == b);
}
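// Note (explanatory, not part of the diff): element-wise Array comparison is what lets
// compile-time structures that hold Array members be compared inside if constexpr, and
// what Map below relies on for key lookup when Key is itself an Array.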
template <typename T, typename... Xs>
__host__ __device__ constexpr auto make_array(Xs&&... xs)
{
......
......@@ -10,6 +10,11 @@
#include "ck/utility/bit_cast.hpp"
#include "ck/utility/print.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/map.hpp"
#include "ck/utility/container_helper.hpp"
#include "ck/utility/statically_indexed_array.hpp"
#include "ck/utility/multi_index.hpp"
......@@ -25,10 +30,6 @@
#include "ck/utility/math_v2.hpp"
#include "ck/utility/math_ext.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/tuple_of_sequence_to_array_of_array.hpp"
#include "ck/utility/macro_func_array_to_sequence.hpp"
#include "ck/utility/macro_func_array_of_array_to_tuple_of_sequence.hpp"
......
......@@ -3,12 +3,13 @@
#pragma once
#include "sequence.hpp"
#include "sequence_helper.hpp"
#include "array.hpp"
#include "tuple.hpp"
#include "tuple_helper.hpp"
#include "statically_indexed_array.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/statically_indexed_array.hpp"
#include "ck/utility/map.hpp"
namespace ck {
......@@ -36,6 +37,7 @@ __host__ __device__ constexpr auto container_push_back(const Tuple<Ts...>& a, co
return container_concat(a, make_tuple(x));
}
// reorder Array
template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto
container_reorder_given_new2old(const Array<TData, NSize>& old_array, Sequence<IRs...> /*new2old*/)
......@@ -55,6 +57,38 @@ container_reorder_given_old2new(const Array<TData, NSize>& old_array, Sequence<I
old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// reorder Array
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto
container_reorder_given_new2old(const Array<TData, NSize>& old_array,
const Map<index_t, index_t>& new2old)
{
Array<TData, NSize> new_array;
for(const auto& [new_pos, old_pos] : new2old)
{
new_array(new_pos) = old_array[old_pos];
}
return new_array;
}
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto
container_reorder_given_old2new(const Array<TData, NSize>& old_array,
const Map<index_t, index_t>& old2new)
{
Array<TData, NSize> new_array;
for(const auto& [old_pos, new_pos] : old2new)
{
new_array(new_pos) = old_array[old_pos];
}
return new_array;
}
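// Worked example (illustrative, not part of the diff): with a runtime Map new2old of
// {0 -> 2, 1 -> 0, 2 -> 1}, an old_array {a, b, c} reorders to {c, a, b}; each pair
// assigns new_array(new_pos) = old_array[old_pos]. The old2new overload walks the inverse
// mapping to the same effect.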
// reorder Tuple
template <typename... Ts, index_t... IRs>
__host__ __device__ constexpr auto container_reorder_given_new2old(const Tuple<Ts...>& old_tuple,
Sequence<IRs...> /*new2old*/)
......@@ -74,6 +108,7 @@ __host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple<T
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// reorder Sequence
template <index_t... Is, index_t... IRs>
__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence<Is...> /* old_seq */,
Sequence<IRs...> /*new2old*/)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/tuple.hpp"
namespace ck {
// naive Map
template <typename Key, typename Data, index_t MaxSize = 128>
struct Map
{
using Pair = Tuple<Key, Data>;
using Impl = Array<Pair, MaxSize>;
Impl impl_;
index_t size_;
struct Iterator
{
Impl& impl_;
index_t pos_;
__host__ __device__ constexpr Iterator(Impl& impl, index_t pos) : impl_{impl}, pos_{pos} {}
__host__ __device__ constexpr Iterator& operator++()
{
pos_++;
return *this;
}
__host__ __device__ constexpr bool operator!=(const Iterator& other) const
{
return other.pos_ != pos_;
}
__host__ __device__ constexpr Pair& operator*() { return impl_.At(pos_); }
};
struct ConstIterator
{
const Impl& impl_;
index_t pos_;
__host__ __device__ constexpr ConstIterator(const Impl& impl, index_t pos)
: impl_{impl}, pos_{pos}
{
}
__host__ __device__ constexpr ConstIterator& operator++()
{
pos_++;
return *this;
}
__host__ __device__ constexpr bool operator!=(const ConstIterator& other) const
{
return other.pos_ != pos_;
}
__host__ __device__ constexpr const Pair& operator*() const { return impl_.At(pos_); }
};
__host__ __device__ constexpr Map() : impl_{}, size_{0} {}
__host__ __device__ constexpr index_t Size() const { return size_; }
__host__ __device__ void Clear() { size_ = 0; }
__host__ __device__ constexpr index_t FindPosition(const Key& key) const
{
for(index_t i = 0; i < Size(); i++)
{
if(impl_[i].template At<0>() == key)
{
return i;
}
}
return size_;
}
__host__ __device__ constexpr ConstIterator Find(const Key& key) const
{
return ConstIterator{impl_, FindPosition(key)};
}
__host__ __device__ constexpr Iterator Find(const Key& key)
{
return Iterator{impl_, FindPosition(key)};
}
__host__ __device__ constexpr const Data& operator[](const Key& key) const
{
const auto it = Find(key);
        // FIXME: assert is crude error handling for a missing key
        assert(it.pos_ < Size());
return impl_[it.pos_].template At<1>();
}
__host__ __device__ constexpr Data& operator()(const Key& key)
{
auto it = Find(key);
// if entry not found
if(it.pos_ == Size())
{
impl_(it.pos_).template At<0>() = key;
size_++;
}
// FIXME
assert(size_ <= MaxSize);
return impl_(it.pos_).template At<1>();
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
__host__ __device__ constexpr ConstIterator begin() const { return ConstIterator{impl_, 0}; }
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
__host__ __device__ constexpr ConstIterator end() const { return ConstIterator{impl_, size_}; }
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
__host__ __device__ constexpr Iterator begin() { return Iterator{impl_, 0}; }
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
__host__ __device__ constexpr Iterator end() { return Iterator{impl_, size_}; }
__host__ __device__ void Print() const
{
printf("Map{size_: %d, ", size_);
//
printf("impl_: [");
//
for(const auto& [key, data] : *this)
{
printf("{key: ");
print(key);
printf(", data: ");
print(data);
printf("}, ");
}
//
printf("]");
//
printf("}");
}
};
} // namespace ck
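// Illustrative usage sketch (not part of the diff): Map is a flat, device-friendly
// key-value container with linear-time lookup, intended for small maps such as the
// rh-major/minor-to-Y tables built in shuffle_distributed_tensor_impl_in_thread:
//
//   ck::Map<ck::index_t, ck::index_t> m;
//   m(3) = 30;             // operator() inserts the key if absent
//   m(7) = 70;
//   const auto v = m[3];   // operator[] asserts the key is present
//   for(const auto& [key, data] : m) { /* visits entries in insertion order */ }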
......@@ -640,6 +640,12 @@ __host__ __device__ constexpr bool operator==(Sequence<Xs...>, Sequence<Ys...>)
return ((Xs == Ys) && ...);
}
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr bool operator!=(Sequence<Xs...> x, Sequence<Ys...> y)
{
return !(x == y);
}
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
{
......
......@@ -9,6 +9,9 @@
namespace ck {
// S: scalar type
// NX: # of vectors before transpose
// NY: # of vectors after transpose
template <typename S,
index_t NX,
index_t NY,
......@@ -81,6 +84,32 @@ struct transpose_vectors<half_t, NX, NY>
});
});
}
// FIXME: duplicated code
__device__ void operator()(const StaticallyIndexedArray<VX, NX>& vx_tuple,
StaticallyIndexedArray<VY, NY>& vy_tuple)
{
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
// loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 2>{}([&](auto iy) {
static_for<0, NX, 2>{}([&](auto ix) {
// reference to 2 half2_t data from vx_tuple
const auto& x_s2_0 = vx_tuple[ix].template AsType<half2_t>()[iy / I2];
const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<half2_t>()[iy / I2];
// reference to 2 half2_t data from vy_tuple
auto& y_s2_0 = vy_tuple(iy).template AsType<half2_t>()(ix / I2);
auto& y_s2_1 = vy_tuple(iy + I1).template AsType<half2_t>()(ix / I2);
// transpose
transpose_fp16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
});
});
}
};
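// Note (explanatory, not part of the diff): transpose_fp16_2x2 turns two half2_t rows
// {a0, a1} and {b0, b1} into the columns {a0, b0} and {a1, b1}; the nested loops above
// assemble the full NX x NY transpose out of these 2x2 pieces, which is why NX and NY
// must both be even.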
// transpose int8 4x4
......
......@@ -258,6 +258,27 @@ struct Tuple<>
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
};
template <typename... Xs>
__host__ __device__ constexpr bool operator==(const Tuple<Xs...>& a, const Tuple<Xs...>& b)
{
bool same = true;
static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
if(a[i] != b[i])
{
same = false;
}
});
return same;
}
template <typename... Xs>
__host__ __device__ constexpr bool operator!=(const Tuple<Xs...>& a, const Tuple<Xs...>& b)
{
return !(a == b);
}
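// Note (explanatory, not part of the diff): Tuple comparison uses static_for rather than a
// runtime loop because each element may have a distinct type; both operands must be the
// same Tuple<Xs...> instantiation for these overloads to apply.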
template <index_t I, typename TTuple>
struct tuple_element
{
......