Unverified Commit 0e92deb7 authored by Chao Liu, committed by GitHub

Tile program init bulk PR (#4)



Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
parent 0077eeb3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_window.hpp"
#include "ck/tile_program/tile/static_distributed_tensor.hpp"
namespace ck {
namespace tile_program {
// detail namespace used by the tile-programming APIs; not supposed to be used directly
namespace detail {
// "Y dimension": Y dimensions inside TileWindowWithStaticDistribution
// input:
// y_slice_origin: starting slice origin of Y dimension
// y_slice_lengths: slice lengths of Y dimension
// output:
// A StaticBuffer holding a slice of thread data; the data layout is hardcoded to be in the order of
// [Y0, Y1, Y2, ...]
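// e.g. (illustrative) with NDimY == 2 and YSliceLengths <2, 4>, the returned
// StaticBuffer holds 2 * 4 = 8 elements, with the last Y dimension varying
// fastest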
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename YIndex,
index_t... YSliceLengths>
__device__ auto load_sliced_thread_data_from_tile_window(
TileWindowWithStaticDistribution<BottomTensorView_, WindowLengths_, TileDistribution_>&
tile_window,
const YIndex& ys_slice_origin,
Sequence<YSliceLengths...>)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using BottomTensorView = remove_cvref_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
constexpr auto tile_dstr = TileDstr{};
constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
static_assert(NDimY == YIndex::Size() && NDimY == sizeof...(YSliceLengths),
"wrong! inconsistent # of dimension");
static_assert(TileWindow::HasStaticTileDistribution(),
"wrong! assume static tile distribution");
constexpr auto y_slice_lengths = Sequence<YSliceLengths...>{};
constexpr index_t thread_element_size =
container_reduce(y_slice_lengths, math::multiplies{}, 1);
StaticBuffer<AddressSpaceEnum::Vgpr, DataType, thread_element_size, true> thread_buf;
constexpr auto tmp = [&y_slice_lengths]() {
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY = 0;
index_t ScalarPerVector = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
{
ScalarPerVector = math::gcd(ys_vector_lengths[i], y_slice_lengths[i]);
VectorDimY = i;
}
}
return make_tuple(VectorDimY, ScalarPerVector);
}();
constexpr index_t VectorDimY = tmp.template At<0>();
constexpr index_t ScalarPerVector = tmp.template At<1>();
// FIXME
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
using SFC_Ys =
SpaceFillingCurve<decltype(y_slice_lengths), DimAccessOrder, decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Ys::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// move to slice origin
const auto ps_ys_slice_origin = container_concat(Array<index_t, NDimP>{0}, ys_slice_origin);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(ps_ys_slice_origin);
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// read from bottom tensor
const vector_t vec_value =
tile_window.GetBottomTensorView().template GetVectorizedElements<vector_t>(
tile_window.GetBottomTensorThreadCoordinate());
const vector_type_t vec{vec_value};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// write into distributed tensor
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
thread_buf.template At<d>() = vec.template AsType<DataType>()[j];
});
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
// move thread coordinate back to origin
{
constexpr auto idx_diff_ys = SFC_Ys::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
// move back to origin
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(MultiIndex<NDimP + NDimY>{0} -
ps_ys_slice_origin);
return thread_buf;
}
} // namespace detail
// FIXME: host dummy function for tile program
template <typename BottomTensorView_, typename WindowLengths_, typename TileDistribution_>
__host__ auto load_tile(
const TileWindowWithStaticDistribution<BottomTensorView_, WindowLengths_, TileDistribution_>&
tile_window)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using BottomTensorView = remove_cvref_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
static_assert(is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileWindow::HasStaticTileDistribution(), "wrong!");
return make_static_distributed_tensor<DataType>(tile_window.GetTileDistribution());
}
template <typename BottomTensorView_, typename WindowLengths_, typename TileDistribution_>
__device__ auto
load_tile(TileWindowWithStaticDistribution<BottomTensorView_, WindowLengths_, TileDistribution_>&
tile_window)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using BottomTensorView = remove_cvref_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
static_assert(is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileWindow::HasStaticTileDistribution(), "wrong!");
constexpr auto tile_dstr = TileDstr{};
constexpr index_t NDimY = tile_dstr.GetYs2DDescriptor().GetNumOfDimension();
auto dstr_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
dstr_tensor.GetThreadBuffer() = detail::load_sliced_thread_data_from_tile_window(
tile_window, MultiIndex<NDimY>{0}, to_sequence(tile_dstr.GetYs2DDescriptor().GetLengths()));
return dstr_tensor;
}
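// Usage sketch (hypothetical names, not part of this commit): load one tile of
// data into per-thread registers through a distributed window,
//
//   auto a_window = make_tile_window(a_view, a_window_lengths, {iM0, 0}, a_dstr);
//   auto a_tile   = load_tile(a_window); // StaticDistributedTensor<DataType, ...>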
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
namespace ck {
namespace tile_program {
template <typename DataType_, typename StaticTileDistribution_>
struct StaticDistributedTensor
{
using DataType = remove_cvref_t<DataType_>;
using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;
static_assert(StaticTileDistribution::IsStatic(),
"wrong! StaticTileDistribution should be known at compile tile");
using ThreadTensorDesc = remove_cvref_t<decltype(StaticTileDistribution{}.GetYs2DDescriptor())>;
static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.GetElementSpaceSize();
__host__ __device__ static constexpr auto GetNumOfDimension()
{
return StaticTileDistribution::GetNumOfDimensionX();
}
__host__ __device__ static constexpr auto GetLengths()
{
return StaticTileDistribution::GetLengths();
}
__host__ __device__ static constexpr auto GetTileDistribution()
{
return StaticTileDistribution{};
}
__host__ __device__ static constexpr auto GetDistributedSpans()
{
return StaticTileDistribution::GetDistributedSpans();
}
__host__ __device__ void Initialize(const DataType& x) { thread_buf_.Initialize(x); }
__host__ __device__ constexpr const auto& GetThreadBuffer() const { return thread_buf_; }
__host__ __device__ constexpr auto& GetThreadBuffer() { return thread_buf_; }
__host__ __device__ static constexpr index_t GetThreadBufferSize()
{
return kThreadElementSpaceSize;
}
template <index_t... YSliceOrigins, index_t... YSliceLengths>
__host__ __device__ auto GetSlicedThreadData(Sequence<YSliceOrigins...>,
Sequence<YSliceLengths...>) const
{
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
"wrong!");
constexpr auto sliced_thread_tensor_desc =
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
sliced_thread_tensor_desc.GetElementSpaceSize(),
true>
sliced_thread_data;
static_ford<Sequence<YSliceLengths...>>{}([&](auto idx) {
constexpr auto idx_ys = idx + Sequence<YSliceOrigins...>{};
sliced_thread_data(Number<sliced_thread_tensor_desc.CalculateOffset(idx)>{}) =
thread_buf_[Number<ThreadTensorDesc{}.CalculateOffset(idx_ys)>{}];
});
return sliced_thread_data;
}
template <index_t... YSliceOrigins, index_t... YSliceLengths, index_t NSlicedData>
__host__ __device__ void SetSlicedThreadData(
Sequence<YSliceOrigins...>,
Sequence<YSliceLengths...>,
const StaticBuffer<AddressSpaceEnum::Vgpr, DataType, NSlicedData, true>& sliced_thread_data)
{
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
"wrong!");
constexpr auto sliced_thread_tensor_desc =
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
static_ford<Sequence<YSliceLengths...>>{}([&](auto idx) {
constexpr auto idx_ys = idx + Sequence<YSliceOrigins...>{};
thread_buf_(Number<ThreadTensorDesc{}.CalculateOffset(idx_ys)>{}) =
sliced_thread_data[Number<sliced_thread_tensor_desc.CalculateOffset(idx)>{}];
});
}
template <index_t... Ys>
__host__ __device__ auto GetElementFromYsIndex(Sequence<Ys...> idx_ys) const
{
return thread_buf_[Number<ThreadTensorDesc{}.CalculateOffset(idx_ys)>{}];
}
template <index_t... Ys>
__host__ __device__ void SetElementFromYsIndex(Sequence<Ys...> idx_ys, const DataType& v)
{
thread_buf_(Number<ThreadTensorDesc{}.CalculateOffset(idx_ys)>{}) = v;
}
template <typename TileDistributedIndices>
__host__ __device__ auto GetElementFromTileDistributedIndices(TileDistributedIndices) const
{
static_assert(is_static_v<TileDistributedIndices>, "wrong!");
constexpr auto y_idx =
GetTileDistribution().GetYIndicesFromDistributedIndices(TileDistributedIndices{});
return GetElementFromYsIndex(y_idx);
}
template <typename TileDistributedIndices>
__host__ __device__ void SetElementFromTileDistributedIndices(TileDistributedIndices,
const DataType& v)
{
static_assert(is_static_v<TileDistributedIndices>, "wrong!");
constexpr auto y_idx =
GetTileDistribution().GetYIndicesFromDistributedIndices(TileDistributedIndices{});
return SetElementFromYsIndex(y_idx, v);
}
//
StaticBuffer<AddressSpaceEnum::Vgpr, DataType, kThreadElementSpaceSize, true> thread_buf_;
};
template <typename DataType, typename StaticTileDistribution>
__host__ __device__ constexpr auto make_static_distributed_tensor(const StaticTileDistribution&)
{
return StaticDistributedTensor<remove_cvref_t<DataType>,
remove_cvref_t<StaticTileDistribution>>{};
}
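// Usage sketch (hypothetical 2-Y-dimension distribution, not part of this
// commit): round-trip one element of the per-thread buffer by its Ys index,
//
//   auto t = make_static_distributed_tensor<float>(some_dstr);
//   t.SetElementFromYsIndex(Sequence<0, 1>{}, 3.f);
//   float v = t.GetElementFromYsIndex(Sequence<0, 1>{}); // v == 3.f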
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
namespace tile_program {
template <typename RsLengths_, // Sequence<...>
typename HsLengthss_, // Tuple<Sequence<...>, ...>
typename Ps2RHssMajor_, // Tuple<Sequence<...>, ...>
typename Ps2RHssMinor_, // Tuple<Sequence<...>, ...>
typename Ys2RHsMajor_, // Sequence<...>
typename Ys2RHsMinor_> // Sequence<...>
struct StaticTileDistributionEncoding
{
using RsLengths = remove_cvref_t<RsLengths_>;
using HsLengthss = remove_cvref_t<HsLengthss_>;
using Ps2RHssMajor = remove_cvref_t<Ps2RHssMajor_>;
using Ps2RHssMinor = remove_cvref_t<Ps2RHssMinor_>;
using Ys2RHsMajor = remove_cvref_t<Ys2RHsMajor_>;
using Ys2RHsMinor = remove_cvref_t<Ys2RHsMinor_>;
static_assert(Ps2RHssMajor::Size() == Ps2RHssMinor::Size(), "wrong!");
static_assert(Ys2RHsMajor::Size() == Ys2RHsMinor::Size(), "wrong!");
static constexpr index_t NDimX = HsLengthss::Size();
static constexpr index_t NDimP = Ps2RHssMajor::Size();
static constexpr index_t NDimY = Ys2RHsMajor::Size();
static constexpr index_t NDimR = RsLengths::Size();
// FIXME: move into Detail
static constexpr auto rs_lengths_ = RsLengths{};
static constexpr auto hs_lengthss_ = HsLengthss{};
static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
// redundant but useful info
struct Detail
{
// ndim_rh_major_, ndim_span_major_
static constexpr index_t ndim_rh_major_ = NDimX + 1;
static constexpr index_t ndim_span_major_ = NDimX;
// ndims_rhs_minor_[ndim_rh_major_]
static constexpr auto ndims_rhs_minor_ = generate_array(
[](auto i) {
if constexpr(i.value == 0)
{
return rs_lengths_.Size();
}
else
{
return hs_lengthss_[i - Number<1>{}].Size();
}
},
Number<ndim_rh_major_>{});
// max_ndim_rh_minor_
static constexpr index_t max_ndim_rh_minor_ =
container_reduce(ndims_rhs_minor_, math::maximize<index_t>{}, 0);
// rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
static constexpr auto rhs_lengthss_ =
to_array_of_array(container_concat(make_tuple(rs_lengths_), hs_lengthss_));
// rhs_major_minor_to_ys_[ndim_rh_major_][max_ndim_rh_minor_]
static constexpr auto rhs_major_minor_to_ys_ = [] {
Array<Array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};
static_for<0, NDimY, 1>{}([&](auto i) {
constexpr index_t rh_major = ys_to_rhs_major_[i];
constexpr index_t rh_minor = ys_to_rhs_minor_[i];
rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
});
return rhs_major_minor_to_ys_tmp;
}();
// ndims_span_minor_[ndim_span_major_]
static constexpr auto ndims_span_minor_ = [] {
Array<index_t, NDimX> ndims_span_minor{0};
for(index_t i = 0; i < NDimY; i++)
{
const index_t span_major = ys_to_rhs_major_[i] - 1;
ndims_span_minor(span_major)++;
}
return ndims_span_minor;
}();
// max_ndim_span_minor_
static constexpr index_t max_ndim_span_minor_ =
container_reduce(ndims_span_minor_, math::maximize<index_t>{}, 0);
// rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
static constexpr auto rhs_major_minor_to_span_minor_ = [] {
Array<Array<index_t, max_ndim_rh_minor_>, ndim_rh_major_> rhs_major_minor_to_span_minor{
{-1}};
static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major];
index_t cnt_ndim_span_minor = 0;
static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];
if(idim_y >= 0)
{
rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
cnt_ndim_span_minor++;
}
});
});
return rhs_major_minor_to_span_minor;
}();
// ys_to_span_major_[NDimY]
static constexpr auto ys_to_span_major_ =
generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, Number<NDimY>{});
// ys_to_span_minor_[NDimY]
static constexpr auto ys_to_span_minor_ = generate_array(
[](auto i) {
return rhs_major_minor_to_span_minor_[ys_to_rhs_major_[i]][ys_to_rhs_minor_[i]];
},
Number<NDimY>{});
// distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
static constexpr auto distributed_spans_lengthss_ = [] {
Array<Array<index_t, max_ndim_span_minor_>, ndim_span_major_>
distributed_spans_lengthss{{-1}};
static_for<0, NDimY, 1>{}([&](auto i) {
const index_t rh_major = ys_to_rhs_major_[i];
const index_t rh_minor = ys_to_rhs_minor_[i];
const index_t h_length = hs_lengthss_[Number<rh_major - 1>{}][rh_minor];
const index_t span_major = rh_major - 1;
const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];
distributed_spans_lengthss(span_major)(span_minor) = h_length;
});
return distributed_spans_lengthss;
}();
// ndims_distributed_spans_minor_[ndim_span_major_]
static constexpr auto ndims_distributed_spans_minor_ = [] {
Array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};
static_for<0, NDimY, 1>{}([&](auto i) {
const index_t span_major = ys_to_rhs_major_[i] - 1;
ndims_distributed_spans_minor(span_major)++;
});
return ndims_distributed_spans_minor;
}();
// does_p_own_r_[NDimP][NDimR]
static constexpr auto does_p_own_r_ = [] {
if constexpr(NDimR > 0)
{
Array<Array<bool, NDimR>, NDimP> does_p_own_r{{false}};
static_for<0, NDimP, 1>{}([&](auto idim_p) {
constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].Size();
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
if constexpr(rh_major == 0)
{
does_p_own_r(idim_p)(rh_minor) = true;
}
});
});
return does_p_own_r;
}
else
{
return Array<Array<bool, NDimR>, NDimP>{};
}
}();
// ps_over_rs_derivative_[NDimP][NDimR]
static constexpr auto ps_over_rs_derivative_ = [] {
if constexpr(NDimR > 0)
{
Array<Array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};
static_for<0, NDimP, 1>{}([&](auto idim_p) {
constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].Size();
index_t p_over_rh_derivative = 1;
static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];
if constexpr(rh_major == 0)
{
ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
}
p_over_rh_derivative *= rh_length;
});
});
return ps_over_rs_derivative;
}
else
{
return Array<Array<index_t, NDimR>, NDimP>{};
}
}();
__host__ __device__ void Print() const
{
printf("StaticTileDistributionEncoding::Detail{");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("ndim_span_major_: ");
print(ndim_span_major_);
printf(", ");
//
printf("ndims_rhs_minor_: ");
print(ndims_rhs_minor_);
printf(", ");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("max_ndim_rh_minor_: ");
print(max_ndim_rh_minor_);
printf(", ");
//
printf("rhs_lengthss_: ");
print(rhs_lengthss_);
printf(", ");
//
printf("rhs_major_minor_to_ys_: ");
print(rhs_major_minor_to_ys_);
printf(", ");
//
printf("ndims_span_minor_: ");
print(ndims_span_minor_);
printf(", ");
//
printf("max_ndim_span_minor_: ");
print(max_ndim_span_minor_);
printf(", ");
//
printf("ys_to_span_major_: ");
print(ys_to_span_major_);
printf(", ");
//
printf("ys_to_span_minor_: ");
print(ys_to_span_minor_);
printf(", ");
//
printf("distributed_spans_lengthss_: ");
print(distributed_spans_lengthss_);
printf(", ");
//
printf("ndims_distributed_spans_minor_: ");
print(ndims_distributed_spans_minor_);
printf(", ");
//
printf("ps_over_rs_derivative_: ");
print(ps_over_rs_derivative_);
//
printf("}");
}
};
__host__ __device__ void Print() const
{
printf("StaticTileDistributionEncoding{");
//
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
//
printf("rs_lengths_: ");
print(rs_lengths_);
printf(", ");
//
printf("hs_lengthss_: ");
print(hs_lengthss_);
printf(", ");
//
printf("ps_to_rhss_major_: ");
print(ps_to_rhss_major_);
printf(", ");
//
printf("ps_to_rhss_minor_: ");
print(ps_to_rhss_minor_);
printf(", ");
//
printf("ys_to_rhs_major_: ");
print(ys_to_rhs_major_);
printf(", ");
//
printf("ys_to_rhs_minor_: ");
print(ys_to_rhs_minor_);
printf(", ");
//
printf("Detail: ");
print(Detail{});
//
printf("}");
}
};
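// Minimal sketch (hypothetical encoding, not part of this commit): one X
// dimension of length 2 * 4 = 8 with no replication; P0 owns the "2" component
// and Y0 owns the "4" component of X0's split:
//
//   using Encode = StaticTileDistributionEncoding<
//       Sequence<>,            // RsLengths:    no R dimension
//       Tuple<Sequence<2, 4>>, // HsLengthss:   X0 split into 2 x 4
//       Tuple<Sequence<1>>,    // Ps2RHssMajor: P0 -> RH-major 1 (i.e. X0)
//       Tuple<Sequence<0>>,    // Ps2RHssMinor: P0 -> RH-minor 0 (the "2")
//       Sequence<1>,           // Ys2RHsMajor:  Y0 -> RH-major 1 (i.e. X0)
//       Sequence<1>>;          // Ys2RHsMinor:  Y0 -> RH-minor 1 (the "4")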
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
namespace ck {
namespace tile_program {
namespace detail {
template <typename OuterDstr, typename InnerDstr>
__host__ __device__ constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
{
static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");
constexpr index_t NDimHMajor = OuterDstr::NDimX;
using RsLengths =
sequence_merge_t<typename OuterDstr::RsLengths, typename InnerDstr::RsLengths>;
constexpr auto hs_lengthss = generate_tuple(
[&](auto i) {
return merge_sequences(typename OuterDstr::HsLengthss{}[i],
typename InnerDstr::HsLengthss{}[i]);
},
Number<NDimHMajor>{});
//
constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
Array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;
// R dimension
rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::Size();
// Hs dimensions
static_for<0, NDimHMajor, 1>{}([&](auto i) {
rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].Size();
});
return rhs_major_2_ndim_outer_rhs_minor_;
}();
// Ps2RHssMinor
constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
[&](auto p) {
constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];
constexpr index_t ndim_tmp = inner_p_2_rhss_minor.Size();
constexpr auto updated_inner_p_2_rhss_minor = [&]() {
Array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;
for(index_t i = 0; i < ndim_tmp; i++)
{
index_t rh_major = inner_p_2_rhss_major[i];
index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
}
return updated_inner_p_2_rhss_minor_;
}();
return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
},
Number<InnerDstr::NDimP>{});
// Ys2RHsMinor
constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};
constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.Size();
constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
Array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;
for(index_t i = 0; i < ndim_tmp; i++)
{
index_t rh_major = inner_ys_2_rhs_major[i];
index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
}
return updated_inner_ys_2_rhs_minor__;
}();
return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
}();
//
constexpr auto ps_2_rhss_major =
container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});
constexpr auto ps_2_rhss_minor =
container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
//
constexpr auto ys_2_rhs_major =
merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});
constexpr auto ys_2_rhs_minor =
merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
return StaticTileDistributionEncoding<RsLengths,
remove_cvref_t<decltype(hs_lengthss)>,
remove_cvref_t<decltype(ps_2_rhss_major)>,
remove_cvref_t<decltype(ps_2_rhss_minor)>,
remove_cvref_t<decltype(ys_2_rhs_major)>,
remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
}
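// Sketch of the embedding rule (illustrative): R lengths of the two encodings
// concatenate; per X dimension, the inner H-minor components are appended
// after the outer ones (e.g. outer X0 split <2> embedded with inner X0 split
// <4> yields X0 split <2, 4>), and the inner P/Y mappings are shifted by the
// number of outer H-minor components of the RH-major they point to.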
template <typename InDstr, index_t... InReduceDimXs>
__host__ __device__ constexpr auto
make_reduce_tile_distribution_encoding_impl(InDstr, Sequence<InReduceDimXs...> reduce_dim_xs_in)
{
constexpr auto I1 = Number<1>{};
// FIXME: increase if fail
constexpr index_t max_ndim_r_out = 20;
constexpr index_t max_ndim_y_out = 20;
//
constexpr index_t ndim_p = InDstr::NDimP;
constexpr index_t ndim_x_in = InDstr::NDimX;
constexpr index_t ndim_y_in = InDstr::NDimY;
constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1;
constexpr index_t ndim_x_out = ndim_x_in - sizeof...(InReduceDimXs);
constexpr index_t max_ndim_rh_minor_in = InDstr::Detail::max_ndim_rh_minor_;
// ndims_ps_low
constexpr auto ndims_ps_low = generate_array(
[&](auto i) { return InDstr::ps_to_rhss_major_[i].Size(); }, Number<ndim_p>{});
// is_rh_major_in_for_reduce
Array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};
for(index_t i = 0; i < reduce_dim_xs_in.Size(); i++)
{
index_t rh_major = reduce_dim_xs_in[i] + 1;
is_rh_major_in_for_reduce(rh_major) = true;
}
// is_y_in_for_reduce
Array<bool, ndim_y_in> is_y_in_for_reduce{false};
for(index_t i = 0; i < ndim_y_in; i++)
{
index_t rh_major = InDstr::ys_to_rhs_major_[i];
if(is_rh_major_in_for_reduce[rh_major])
{
is_y_in_for_reduce(i) = true;
}
}
// is_rh_minor_in_for_y_reduce
Array<Array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}};
static_for<0, ndim_y_in, 1>{}([&](auto i) {
index_t rh_major = InDstr::ys_to_rhs_major_[i];
index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
if(is_y_in_for_reduce[i])
{
is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
}
});
// in2out_rh_major
Array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
index_t cnt_ndim_rh_major_out = 0;
for(index_t i = 0; i < ndim_rh_major_in; i++)
{
if(is_rh_major_in_for_reduce[i])
{
in2out_rh_major(i) = 0;
}
else
{
in2out_rh_major(i) = cnt_ndim_rh_major_out;
cnt_ndim_rh_major_out++;
}
}
// rs_lengths_out, in2out_rh_minor
Array<index_t, max_ndim_r_out> rs_lengths_out{-1};
Array<Array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};
// loop over input R dim
for(index_t i = 0; i < InDstr::rs_lengths_.Size(); i++)
{
// rs_lengths_out
rs_lengths_out(i) = InDstr::rs_lengths_[i];
// in2out_rh_minor
in2out_rh_minor(0)(i) = i;
}
// loop over input H Dim
index_t cnt_ndim_r_out = InDstr::rs_lengths_.Size();
static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
constexpr auto h_major_in = rh_major_in - I1;
constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].Size();
if(is_rh_major_in_for_reduce[rh_major_in])
{
for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
{
if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
{
// rs_lengths_out
rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
// in2out_rh_minor
in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
cnt_ndim_r_out++;
}
}
}
else
{
for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
{
// in2out_rh_minor
in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
}
}
});
// ndim_r_out
const index_t ndim_r_out = cnt_ndim_r_out;
// ndims_hs_minor_out, hs_lengthss_out
Array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
Array<Array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};
index_t cnt_ndim_x_out = 0;
static_for<0, ndim_x_in, 1>{}([&](auto i) {
if(not is_rh_major_in_for_reduce[i + I1])
{
// ndims_hs_minor_out
ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].Size();
// hs_lengthss_out
static_for<0, InDstr::hs_lengthss_[i].Size(), 1>{}(
[&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
cnt_ndim_x_out++;
}
});
// ps_to_rhss_major_out, ps_to_rhss_minor_out
Array<Array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
Array<Array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};
static_for<0, ndim_p, 1>{}([&](auto idim_p) {
static_for<0, InDstr::ps_to_rhss_major_[idim_p].Size(), 1>{}([&](auto idim_low) {
index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
});
});
// ys_to_rhs_major_out, ys_to_rhs_minor_out
Array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
Array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};
index_t cnt_ndim_y_out = 0;
static_for<0, ndim_y_in, 1>{}([&](auto i) {
if(not is_y_in_for_reduce[i])
{
index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
cnt_ndim_y_out++;
}
});
// ndim_y_out
const index_t ndim_y_out = cnt_ndim_y_out;
//
return make_tuple(ndim_x_out,
ndim_p,
ndim_y_out,
ndim_r_out,
ndims_hs_minor_out,
ndims_ps_low,
rs_lengths_out,
hs_lengthss_out,
ps_to_rhss_major_out,
ps_to_rhss_minor_out,
ys_to_rhs_major_out,
ys_to_rhs_minor_out);
}
template <typename InDstr, index_t... InReduceDimXs>
__host__ __device__ constexpr auto
make_reduce_tile_distribution_encoding(InDstr, Sequence<InReduceDimXs...> reduce_dim_xs_in)
{
constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);
constexpr index_t ndim_x = impl.template At<0>();
constexpr index_t ndim_p = impl.template At<1>();
constexpr index_t ndim_y = impl.template At<2>();
constexpr index_t ndim_r = impl.template At<3>();
constexpr auto ndims_hs_minor = impl.template At<4>();
constexpr auto ndims_ps_low = impl.template At<5>();
constexpr auto rs_lengths_impl = impl.template At<6>();
constexpr auto hs_lengthss_impl = impl.template At<7>();
constexpr auto ps_to_rhss_major_impl = impl.template At<8>();
constexpr auto ps_to_rhss_minor_impl = impl.template At<9>();
constexpr auto ys_to_rhs_major_impl = impl.template At<10>();
constexpr auto ys_to_rhs_minor_impl = impl.template At<11>();
constexpr auto rs_lengths = TO_SEQUENCE(rs_lengths_impl, ndim_r);
constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
constexpr auto ps_to_rhss_major =
TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
constexpr auto ps_to_rhss_minor =
TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);
return StaticTileDistributionEncoding<remove_cvref_t<decltype(rs_lengths)>,
remove_cvref_t<decltype(hs_lengthss)>,
remove_cvref_t<decltype(ps_to_rhss_major)>,
remove_cvref_t<decltype(ps_to_rhss_minor)>,
remove_cvref_t<decltype(ys_to_rhs_major)>,
remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
#if 0
if(ProgramServer::get_block_id() == 0 && ProgramServer::get_thread_id() == 0)
{
printf("ndim_x: ");
print(ndim_x);
printf("\n");
printf("ndim_p: ");
print(ndim_p);
printf("\n");
printf("ndim_y: ");
print(ndim_y);
printf("\n");
printf("ndim_r: ");
print(ndim_r);
printf("\n");
printf("ndims_hs_minor: ");
print(ndims_hs_minor);
printf("\n");
printf("ndims_ps_low: ");
print(ndims_ps_low);
printf("\n");
printf("rs_lengths: ");
print(rs_lengths);
printf("\n");
printf("hs_lengthss: ");
print(hs_lengthss);
printf("\n");
printf("ps_to_rhss_major: ");
print(ps_to_rhss_major);
printf("\n");
printf("ps_to_rhss_minor: ");
print(ps_to_rhss_minor);
printf("\n");
printf("ys_to_rhs_major: ");
print(ys_to_rhs_major);
printf("\n");
printf("ys_to_rhs_minor: ");
print(ys_to_rhs_minor);
printf("\n");
}
#endif
}
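// Sketch (illustrative): reducing over X-dim 1 of an encoding with NDimX == 2,
//
//   make_reduce_tile_distribution_encoding(InDstr{}, Sequence<1>{})
//
// drops the Y components that live on X1 (the dimensions being reduced over),
// turns X1's remaining H components (e.g. those owned by Ps) into additional
// output R dimensions, and renumbers the surviving RH-major indices, so the
// result has NDimX == 1.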
} // namespace detail
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/store_tile_impl_static_distribution.hpp"
#include "ck/tile_program/tile/store_tile_impl_static_lengths.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
namespace ck {
namespace tile_program {
// FIXME: host dummy function for tile program
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
__host__ void
store_tile(TileWindowWithStaticDistribution<BottomTensorView_, WindowLengths_, TileDistribution_>&,
const StaticDistributedTensor<DataType_, TileDistribution_>&)
{
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
__device__ void
store_tile(TileWindowWithStaticDistribution<BottomTensorView_, WindowLengths_, TileDistribution_>&
tile_window,
const StaticDistributedTensor<DataType_, TileDistribution_>& dstr_tensor)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using BottomTensorView = remove_cvref_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
static_assert(TileWindow::HasStaticTileDistribution(), "wrong!");
constexpr auto tile_dstr = TileDstr{};
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.GetYs2DDescriptor().GetLengths());
constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
constexpr auto tmp = []() {
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY = 0;
index_t ScalarPerVector = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
{
ScalarPerVector = ys_vector_lengths[i];
VectorDimY = i;
}
}
return make_tuple(VectorDimY, ScalarPerVector);
}();
constexpr index_t VectorDimY = tmp.template At<0>();
constexpr index_t ScalarPerVector = tmp.template At<1>();
// FIXME:
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
using SFC_Ys = SpaceFillingCurve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Ys::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// read from distributed tensor
vector_type_t vec;
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
vec.template AsType<DataType>()(j) = dstr_tensor.GetThreadBuffer().template At<d>();
});
const vector_t vec_value = vec.template AsType<vector_t>().template At<0>();
// write into bottom tensor
tile_window.GetBottomTensorView().template SetVectorizedElements<vector_t>(
tile_window.GetBottomTensorThreadCoordinate(), vec_value);
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
// move thread coordinate back to origin
{
constexpr auto idx_diff_ys = SFC_Ys::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
}
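// Usage sketch (hypothetical names, not part of this commit): the counterpart
// of load_tile for writing results back through a distributed window,
//
//   auto c_tile = tile_elementwise_in(some_func, a_tile, b_tile);
//   store_tile(c_window, c_tile); // c_window shares c_tile's distribution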
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_window.hpp"
namespace ck {
namespace tile_program {
// FIXME: host dummy function for tile program
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
__host__ void store_tile(TileWindowWithStaticLengths<BottomTensorView_, WindowLengths_>&,
const StaticDistributedTensor<DataType_, TileDistribution_>&)
{
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
__device__ void
store_tile(TileWindowWithStaticLengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const StaticDistributedTensor<DataType_, TileDistribution_>& dstr_tensor)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using BottomTensorView = remove_cvref_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
constexpr auto tile_dstr = TileDstr{};
auto tile_window = make_tile_window(tile_window_tmp.GetBottomTensorView(),
tile_window_tmp.GetWindowLengths(),
tile_window_tmp.GetWindowOrigin(),
tile_dstr);
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.GetYs2DDescriptor().GetLengths());
constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
constexpr auto tmp = []() {
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY = 0;
index_t ScalarPerVector = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
{
ScalarPerVector = ys_vector_lengths[i];
VectorDimY = i;
}
}
return make_tuple(VectorDimY, ScalarPerVector);
}();
constexpr index_t VectorDimY = tmp.template At<0>();
constexpr index_t ScalarPerVector = tmp.template At<1>();
// FIXME:
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
using SFC_Ys = SpaceFillingCurve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Ys::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// read from distributed tensor
vector_type_t vec;
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
vec.template AsType<DataType>()(j) = dstr_tensor.GetThreadBuffer().template At<d>();
});
const vector_t vec_value = vec.template AsType<vector_t>().template At<0>();
// write into bottom tensor
tile_window.GetBottomTensorView().template SetVectorizedElements<vector_t>(
tile_window.GetBottomTensorThreadCoordinate(), vec_value);
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
// move thread coordinate back to origin
{
constexpr auto idx_diff_ys = SFC_Ys::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
}
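// Usage sketch (hypothetical names, not part of this commit): with a window
// that carries only static lengths, the tensor's own distribution is attached
// above via make_tile_window before storing,
//
//   store_tile(c_window_with_static_lengths, c_dstr_tensor);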
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/macro_func_tensor_adaptor_from_encoding.hpp"
#include "ck/tile_program/tile/static_tile_distribution_encoding.hpp"
namespace ck {
namespace tile_program {
// distributed span
template <index_t... PartialHsLengths>
struct TileDistributedSpan
{
using Impl = Sequence<PartialHsLengths...>;
static constexpr auto impl_ = Impl{};
__host__ __device__ static constexpr bool IsStatic() { return true; }
};
// distributed index
template <index_t... PartialHsIndices>
struct TileDistributedIndex
{
using Impl = Sequence<PartialHsIndices...>;
static constexpr auto impl_ = Impl{};
__host__ __device__ static constexpr bool IsStatic() { return true; }
};
namespace detail {
template <index_t... Is>
__host__ __device__ constexpr auto make_tile_distributed_span(Sequence<Is...>)
{
return TileDistributedSpan<Is...>{};
}
template <index_t... Is>
__host__ __device__ constexpr auto make_tile_distributed_index(Sequence<Is...>)
{
return TileDistributedIndex<Is...>{};
}
} // namespace detail
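// Sketch (illustrative): these helpers lift a Sequence into the wrapper types,
// e.g.
//
//   static_assert(
//       is_same_v<decltype(detail::make_tile_distributed_span(Sequence<4, 2>{})),
//                 TileDistributedSpan<4, 2>>,
//       "make_tile_distributed_span maps Sequence<Is...> to TileDistributedSpan<Is...>");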
template <typename PsYs2XsAdaptor_,
typename Ys2DDescriptor_,
typename StaticTileDistributionEncoding_,
typename TileDistributionDetail_> // FIXME: this is for holding ad-hoc but useful info;
                                  // should be more elegant
struct TileDistribution
{
using PsYs2XsAdaptor = remove_cvref_t<PsYs2XsAdaptor_>;
using Ys2DDescriptor = remove_cvref_t<Ys2DDescriptor_>;
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
using DstrDetail = remove_cvref_t<TileDistributionDetail_>;
static_assert(PsYs2XsAdaptor::IsStatic() && Ys2DDescriptor::IsStatic(),
"wrong! should be static");
static constexpr index_t NDimX = PsYs2XsAdaptor::GetNumOfBottomDimension();
static constexpr index_t NDimY = Ys2DDescriptor::GetNumOfTopDimension();
static constexpr index_t NDimP = PsYs2XsAdaptor::GetNumOfTopDimension() - NDimY;
static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR;
PsYs2XsAdaptor ps_ys_to_xs_;
Ys2DDescriptor ys_to_d_;
__host__ __device__ static constexpr index_t GetNumOfDimensionX() { return NDimX; }
__host__ __device__ static constexpr index_t GetNumOfDimensionY() { return NDimY; }
__host__ __device__ static constexpr index_t GetNumOfDimensionP() { return NDimP; }
__host__ __device__ static constexpr index_t GetNumOfDimensionR() { return NDimR; }
__host__ __device__ static constexpr auto GetLengths()
{
#if 0
// FIXME: TensorAdaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed
ps_ys_to_xs_.GetBottomDimensionLengths();
#else
return generate_tuple(
[&](auto i) {
constexpr index_t x_length =
container_reduce(typename DstrEncode::HsLengthss{}[i], math::multiplies{}, 1);
return Number<x_length>{};
},
Number<NDimX>{});
#endif
}
__host__ __device__ constexpr const auto& GetPsYs2XsAdaptor() const { return ps_ys_to_xs_; }
__host__ __device__ constexpr const auto& GetYs2DDescriptor() const { return ys_to_d_; }
__host__ __device__ static constexpr auto GetStaticTileDistributionEncoding()
{
return DstrEncode{};
}
#if 1
// Calculate Replication index [R0, R1, ...] based on Partition index
// FIXME: very nasty implementation
template <typename PartitionIndex>
__host__ __device__ auto CalculateRsIndexFromPsIndex(const PartitionIndex& ps_idx) const
{
static_assert(PartitionIndex::Size() == NDimP, "wrong!");
const auto ps_ys_idx = container_concat(ps_idx, Array<index_t, NDimY>{0});
const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
Array<index_t, NDimR> rs_idx;
static_for<0, NDimP, 1>{}([&](auto idim_p) {
constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].Size();
static_for<0, ndim_low, 1>{}([&](auto i) {
constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];
// 0-th rh_major is the replicate dimension
if constexpr(rh_major == 0)
{
constexpr index_t adaptor_hidden_id =
DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];
// fill in
rs_idx(rh_minor) = dummy_adaptor_coord.GetHiddenIndex()[adaptor_hidden_id];
}
});
});
return rs_idx;
}
#endif
__host__ __device__ static constexpr auto GetDistributedSpans()
{
constexpr auto distributed_spans_impl = DstrEncode::Detail::distributed_spans_lengthss_;
constexpr auto ndims_spans_minor = DstrEncode::Detail::ndims_distributed_spans_minor_;
return generate_tuple(
[&](auto i) {
constexpr auto span_impl = distributed_spans_impl[i];
constexpr index_t ndim_span_minor = ndims_spans_minor[i];
constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor);
return detail::make_tile_distributed_span(span);
},
Number<NDimX>{});
}
// FIXME: it's hacky to get Y index from Distributed-Index
template <typename DistributedIndices>
__host__ __device__ static constexpr auto GetYIndicesFromDistributedIndices(DistributedIndices)
{
constexpr auto ys_idx_arr = [] {
Array<index_t, NDimY> ys_idx;
static_for<0, NDimY, 1>{}([&](auto i) {
constexpr index_t span_major = DstrEncode::Detail::ys_to_span_major_[i];
constexpr index_t span_minor = DstrEncode::Detail::ys_to_span_minor_[i];
constexpr auto dstr_index = DistributedIndices{}[Number<span_major>{}];
ys_idx(i) = dstr_index.impl_[span_minor];
});
return ys_idx;
}();
constexpr index_t ndim_y = NDimY;
return TO_SEQUENCE(ys_idx_arr, ndim_y);
}
__host__ __device__ static constexpr bool IsStatic()
{
return PsYs2XsAdaptor::IsStatic() && Ys2DDescriptor::IsStatic();
}
__host__ __device__ void Print() const
{
printf("TileDistribution{");
//
printf("StaticTileDistributionEncoding: ");
print(DstrEncode{});
printf(", ");
//
printf("ps_ys_to_xs_: ");
print(ps_ys_to_xs_);
printf(", ");
//
printf("ys_to_d_: ");
print(ys_to_d_);
//
printf("}");
}
};
namespace detail {
template <index_t NDimMax>
__host__ __device__ constexpr auto make_sequential_index(index_t ibegin, index_t iend)
{
Array<index_t, NDimMax> arr{0};
for(index_t i = 0; i < iend - ibegin; ++i)
{
arr(i) = ibegin + i;
}
return arr;
}
// this returns a constexpr adaptor encoding used to build a TileDistribution
template <typename StaticTileDistributionEncoding_>
__host__ __device__ constexpr auto
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
{
using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor;
using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor;
using Ys2RHsMajor = typename StaticTileDistributionEncoding_::Ys2RHsMajor;
using Ys2RHsMinor = typename StaticTileDistributionEncoding_::Ys2RHsMinor;
// FIXME: increase max value if fail
constexpr index_t kMaxNumTransforms = 20;
constexpr index_t kMaxMetaDataSize = 128;
constexpr index_t kMaxNumDim = 10;
using Name = IndexTransformEnum;
using MetaData = MetaDataBuffer<kMaxMetaDataSize>;
using NumDim = index_t;
using Dims = Array<index_t, kMaxNumDim>;
using Lengths = Array<index_t, kMaxNumDim>;
// Tile Adaptor
// bottom dims [x0, x1, x2, ...]
// top dims [p0, p1, ..., y0, y1, ...]
constexpr index_t ndim_x = HsLengthss::Size();
// Dim Ids: [rh_major, rh_minor] to [idim_hidden]
Array<Array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_ids;
Array<Array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_lengths;
auto trans = Array<Tuple<Name, MetaData, NumDim, Dims, NumDim, Dims>, kMaxNumTransforms>{};
index_t num_tran = 0;
index_t hidden_dim_cnt = ndim_x;
// this is Replicate transform
{
constexpr index_t ndim_r_minor = RsLengths::Size();
constexpr auto r_minor_lengths = RsLengths{};
trans(num_tran++) = {
IndexTransformEnum::Replicate,
MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
NumDim{0},
Dims{},
NumDim{ndim_r_minor},
make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};
for(index_t i = 0; i < ndim_r_minor; ++i)
{
rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt;
rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];
hidden_dim_cnt++;
}
};
// these are Unmerge transforms for X dimensions
static_for<0, ndim_x, 1>{}([&trans,
&num_tran,
&hidden_dim_cnt,
&rh_major_minor_to_hidden_ids,
&rh_major_minor_to_hidden_lengths](auto idim_x) {
constexpr auto h_minor_lengths = tuple_element_t<idim_x, HsLengthss>{};
constexpr index_t ndim_h_minor = h_minor_lengths.Size();
trans(num_tran++) = {
IndexTransformEnum::UnMerge,
MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
NumDim{1},
Dims{idim_x},
NumDim{ndim_h_minor},
make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};
for(index_t i = 0; i < ndim_h_minor; ++i)
{
rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt;
rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];
hidden_dim_cnt++;
}
});
// transform: P dimensions
constexpr index_t ndim_p = Ps2RHssMajor::Size();
Dims hidden_dim_id_ps;
static_for<0, ndim_p, 1>{}([&](auto iDimP) {
//
index_t hidden_dim_id_p = hidden_dim_cnt++;
hidden_dim_id_ps(iDimP) = hidden_dim_id_p;
constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP];
static_assert(p2RHsMajor.Size() == p2RHsMinor.Size(), "wrong!");
constexpr index_t ndim_low = p2RHsMajor.Size();
Dims low_dims;
Lengths low_lengths;
for(index_t i = 0; i < ndim_low; ++i)
{
index_t rh_major = p2RHsMajor[i];
index_t rh_minor = p2RHsMinor[i];
low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
}
trans(num_tran++) = {IndexTransformEnum::Merge,
MetaData{to_array<index_t, ndim_low>(low_lengths)},
NumDim{ndim_low},
low_dims,
NumDim{1},
Dims{hidden_dim_id_p}};
});
constexpr index_t ndim_bottom = ndim_x;
constexpr auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);
constexpr auto ys_to_rhs_major = Ys2RHsMajor{};
constexpr auto ys_to_rhs_minor = Ys2RHsMinor{};
constexpr index_t ndim_y = Ys2RHsMajor::Size();
constexpr index_t ndim_top = ndim_p + ndim_y;
auto top_dim_ids = hidden_dim_id_ps;
{
for(index_t i = 0; i < ndim_y; ++i)
{
index_t rh_major = ys_to_rhs_major[i];
index_t rh_minor = ys_to_rhs_minor[i];
top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
}
}
//
const auto ps_ys_to_xs_adaptor_encoding =
make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);
// descriptor: [y0, y1, ...] to [d]
Lengths y_lengths;
index_t d_length = 1;
for(index_t i = 0; i < ndim_y; ++i)
{
index_t rh_major = ys_to_rhs_major[i];
index_t rh_minor = ys_to_rhs_minor[i];
index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
y_lengths(i) = y_length;
d_length *= y_length;
}
auto tran = make_tuple(IndexTransformEnum::UnMerge,
MetaData{to_array<index_t, ndim_y>(y_lengths)},
NumDim{1},
Dims{0},
NumDim{ndim_y},
make_sequential_index<kMaxNumDim>(1, ndim_y + 1));
const auto ys_to_d_adaptor_encoding = make_tuple(
make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);
return make_tuple(ps_ys_to_xs_adaptor_encoding,
ys_to_d_adaptor_encoding,
d_length,
rh_major_minor_to_hidden_ids);
}
// FIXME: this is nasty. Need to find another way to hold this info
template <typename RhMajorMinor2AdaptorHiddenIdss> // Tuple<Sequence<...>, ...>
struct TileDistributionDetail
{
static constexpr auto rh_major_minor_to_adaptor_hidden_idss_ =
to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{});
};
} // namespace detail
// this returns a constexpr TileDistribution
template <typename StaticTileDistributionEncoding_>
__host__ __device__ constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
{
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
constexpr auto adaptor_impl =
detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template At<0>();
constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template At<1>();
constexpr index_t d_length = adaptor_impl.template At<2>();
constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template At<3>();
constexpr auto ps_ys_to_xs_adaptor =
CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
constexpr auto ys_to_d_descriptor =
make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);
//
constexpr index_t ndim_rh_major = DstrEncode::Detail::ndim_rh_major_;
constexpr auto ndims_rhs_minor = DstrEncode::Detail::ndims_rhs_minor_;
constexpr auto rh_major_minor_to_hidden_ids =
TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
return TileDistribution<
remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
remove_cvref_t<decltype(ys_to_d_descriptor)>,
remove_cvref_t<DstrEncode>,
detail::TileDistributionDetail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
// this returns a static TileDistribution
template <typename StaticTileDistributionEncoding_>
__host__ __device__ constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
{
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
constexpr auto adaptor_impl =
detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template At<0>();
constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template At<1>();
constexpr index_t d_length = adaptor_impl.template At<2>();
constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template At<3>();
constexpr auto ps_ys_to_xs_adaptor =
CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
constexpr auto ys_to_d_adaptor =
CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
constexpr auto ys_to_d_descriptor =
make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, Number<d_length>{});
//
constexpr index_t ndim_rh_major = DstrEncode::Detail::ndim_rh_major_;
constexpr auto ndims_rhs_minor = DstrEncode::Detail::ndims_rhs_minor_;
constexpr auto rh_major_minor_to_hidden_ids =
TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
return TileDistribution<
remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
remove_cvref_t<decltype(ys_to_d_descriptor)>,
remove_cvref_t<DstrEncode>,
detail::TileDistributionDetail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
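// Usage sketch (hypothetical, not part of this commit): feed it a
// StaticTileDistributionEncoding (see the minimal sketch in
// static_tile_distribution_encoding.hpp) to get a fully static distribution,
//
//   constexpr auto dstr = make_static_tile_distribution(Encode{});
//   static_assert(decltype(dstr)::NDimY == 1, "one per-thread Y dimension");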
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/static_distributed_tensor.hpp"
namespace ck {
namespace tile_program {
// TODO: support tensors with different distribution
template <typename InOutElementFunc, typename... InOutDstrTensors>
__host__ __device__ void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
InOutDstrTensors&... inout_dstr_tensors)
{
// TODO: make sure all distributed tensors have same lengths and distribution
// static_assert(xxx);
constexpr index_t thread_buffer_size =
type_pack_element<0, InOutDstrTensors...>::GetThreadBufferSize();
static_for<0, thread_buffer_size, 1>{}(
[&](auto i) { inout_element_func(inout_dstr_tensors.GetThreadBuffer()(i)...); });
}
template <typename InElementFunc, typename... InDstrTensors>
__host__ __device__ auto tile_elementwise_in(const InElementFunc& in_element_func,
const InDstrTensors&... in_dstr_tensors)
{
using OutDataType = decltype(in_element_func(typename InDstrTensors::DataType{}...));
// TODO: make sure all distributed tensors have same lengths and distribution
// static_assert(xxx);
constexpr auto in_tile_dstr = type_pack_element<0, InDstrTensors...>::GetTileDistribution();
constexpr index_t thread_buffer_size =
type_pack_element<0, InDstrTensors...>::GetThreadBufferSize();
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
out_dstr_tensor.GetThreadBuffer()(i) =
in_element_func(in_dstr_tensors.GetThreadBuffer()[i]...);
});
return out_dstr_tensor;
}
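// Illustrative usage sketch (assumes distributed tensors a_tile and b_tile
// share lengths and distribution, per the TODOs above):
//
//   // out-of-place: c = a + a
//   auto c_tile = tile_elementwise_in(
//       [](const auto& a) { return a + a; }, a_tile);
//
//   // in-place: b += a
//   tile_elementwise_inout(
//       [](auto& b, const auto& a) { b += a; }, b_tile, a_tile);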
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
namespace tile_program {
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
struct TileGemmShape
{
static constexpr index_t kM = kMPerTile;
static constexpr index_t kN = kNPerTile;
static constexpr index_t kK = kKPerTile;
};
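// e.g. a 128x128x32 block tile (sizes here are illustrative only):
//   using BlockGemmShape = TileGemmShape<128, 128, 32>;
//   static_assert(BlockGemmShape::kM == 128 && BlockGemmShape::kK == 32, "");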
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_adaptor_coordinate.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_window_impl_static_distribution.hpp"
#include "ck/tile_program/tile/tile_window_impl_static_lengths.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_adaptor_coordinate.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
namespace ck {
namespace tile_program {
template <typename BottomTensorView_, typename WindowLengths_, typename StaticTileDistribution_>
struct TileWindowWithStaticDistribution
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = typename BottomTensorView::DataType;
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::GetNumOfTopDimension();
static constexpr index_t NDimBottomTensor = BottomTensorDesc::GetNumOfDimension();
// TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert(is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileDstr::IsStatic(), "wrong!");
static_assert(NDimBottomTensor == WindowAdaptor::GetNumOfBottomDimension(),
"wrong! inconsistent # of diemsnions");
using AdaptorTopIndex = Array<index_t, NDimWindowAdaptorTop>;
using BottomTensorIndex = Array<index_t, NDimBottomTensor>;
using WindowAdaptorCoord =
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
__host__ __device__ constexpr TileWindowWithStaticDistribution() = default;
// FIXME: host dummy constructor for tile program
__host__ constexpr TileWindowWithStaticDistribution(const BottomTensorView& bottom_tensor_view,
const WindowLengths&,
const BottomTensorIndex&,
const TileDstr&)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{},
bottom_tensor_thread_coord_{},
tile_dstr_{},
window_adaptor_thread_coord_{}
{
}
__device__ constexpr TileWindowWithStaticDistribution(
const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
bottom_tensor_thread_coord_{},
tile_dstr_{tile_distribution},
window_adaptor_thread_coord_{
make_tensor_adaptor_coordinate(tile_distribution.GetPsYs2XsAdaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0})}
{
BottomTensorIndex bottom_tensor_thread_origin_idx;
for(index_t i = 0; i < NDimBottomTensor; ++i)
{
bottom_tensor_thread_origin_idx(i) =
window_origin[i] + window_adaptor_thread_coord_.GetBottomIndex()[i];
}
bottom_tensor_thread_coord_ = make_tensor_coordinate(
bottom_tensor_view_.GetTensorDescriptor(), bottom_tensor_thread_origin_idx);
}
__host__ __device__ static constexpr index_t GetNumOfDimension() { return NDimBottomTensor; }
__host__ __device__ static constexpr bool HasStaticTileDistribution()
{
return TileDstr::IsStatic();
}
__host__ __device__ constexpr auto GetWindowLengths() const { return window_lengths_; }
__host__ __device__ constexpr auto GetTileDistribution() const { return tile_dstr_; }
__host__ __device__ constexpr auto GetBottomTensorView() const { return bottom_tensor_view_; }
__host__ __device__ constexpr auto GetWindowOrigin() const { return window_origin_; }
__host__ __device__ constexpr auto GetBottomTensorThreadCoordinate() const
{
return bottom_tensor_thread_coord_;
}
// move thread's window adaptor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
__device__ void MoveWindowAdaptorThreadCoordinate(const AdaptorTopIndex& idx_diff_adaptor)
{
move_tensor_adaptor_coordinate(
tile_dstr_.GetPsYs2XsAdaptor(), window_adaptor_thread_coord_, idx_diff_adaptor);
}
// move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset]
__device__ void MoveBottomTensorThreadCoordinate(const BottomTensorIndex& idx_diff_tensor)
{
move_tensor_coordinate(bottom_tensor_view_.GetTensorDescriptor(),
bottom_tensor_thread_coord_,
idx_diff_tensor);
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
__device__ void
MoveWindowAdaptorAndBottomTensorThreadCoordinate(const AdaptorTopIndex& idx_diff_adaptor_top)
{
Array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
move_tensor_adaptor_coordinate(tile_dstr_.GetPsYs2XsAdaptor(),
window_adaptor_thread_coord_,
idx_diff_adaptor_top,
idx_diff_adaptor_bottom);
move_tensor_coordinate(bottom_tensor_view_.GetTensorDescriptor(),
bottom_tensor_thread_coord_,
idx_diff_adaptor_bottom);
}
// return vector dimension among [y0, y1, ...]
__host__ __device__ static constexpr auto GetWindowAdaptorYsSafeVectorLengthStrides()
{
// bottom tensor top dimension vector lengths and strides
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
BottomTensorDesc::GetTopDimensionSafeVectorLengthStrides();
// window vector lengths/strides
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
// window adaptor [p0, p1, ..., y0, y1, ...]
Array<index_t, WindowAdaptor::GetNumOfHiddenDimension()> window_adaptor_vector_lengths{-1};
Array<index_t, WindowAdaptor::GetNumOfHiddenDimension()> window_adaptor_vector_strides{-1};
constexpr auto window_adaptor_bottom_dims = WindowAdaptor::GetBottomDimensionHiddenIds();
set_container_subset(window_adaptor_vector_lengths,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_lengths);
set_container_subset(window_adaptor_vector_strides,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_strides);
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
WindowAdaptor{}.GetTopDimensionSafeVectorLengthStrides(window_adaptor_vector_lengths,
window_adaptor_vector_strides);
// [y0, y1, ...]
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::GetNumOfDimensionP(),
NDimWindowAdaptorTop,
1>::type{};
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
//
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
// per-thread coordinate for bottom tensor
BottomTensorCoord bottom_tensor_thread_coord_;
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
// thread window coordinate
WindowAdaptorCoord window_adaptor_thread_coord_;
};
// TODO: use strategy
template <typename TensorView_, typename WindowLengths_, typename StaticTileDistribution_>
__host__ __device__ constexpr auto
make_tile_window(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const MultiIndex<TensorView_::GetNumOfDimension()>& origin,
const StaticTileDistribution_& tile_distribution)
{
return TileWindowWithStaticDistribution<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>>{
tensor_view, window_lengths, origin, tile_distribution};
}
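// Illustrative usage sketch (a_tensor_view, a_dstr, i_m and the 128x32 window
// lengths are hypothetical): open a thread-distributed window at block row
// i_m, then step it 32 elements along the second (K-like) dimension:
//
//   auto a_window = make_tile_window(a_tensor_view,
//                                    make_tuple(Number<128>{}, Number<32>{}),
//                                    {i_m, 0},
//                                    a_dstr);
//   move_tile_window(a_window, {0, 32});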
// FIXME: dummy host function for tile program
template <typename TensorView_, typename WindowLengths_, typename StaticTileDistribution_>
__host__ void move_tile_window(
TileWindowWithStaticDistribution<TensorView_, WindowLengths_, StaticTileDistribution_>&,
const MultiIndex<
TileWindowWithStaticDistribution<TensorView_, WindowLengths_, StaticTileDistribution_>::
GetNumOfDimension()>&)
{
}
template <typename TensorView_, typename WindowLengths_, typename StaticTileDistribution_>
__device__ void move_tile_window(
TileWindowWithStaticDistribution<TensorView_, WindowLengths_, StaticTileDistribution_>& window,
const MultiIndex<
TileWindowWithStaticDistribution<TensorView_, WindowLengths_, StaticTileDistribution_>::
GetNumOfDimension()>& step)
{
window.window_origin_ += step;
window.MoveBottomTensorThreadCoordinate(step);
}
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_adaptor_coordinate.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
namespace ck {
namespace tile_program {
template <typename BottomTensorView_, typename WindowLengths_>
struct TileWindowWithStaticLengths
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = typename BottomTensorView::DataType;
static constexpr index_t NDimBottomTensor = BottomTensorDesc::GetNumOfDimension();
static_assert(is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
using BottomTensorIndex = Array<index_t, NDimBottomTensor>;
__host__ __device__ constexpr TileWindowWithStaticLengths() = default;
// FIXME: host dummy constructor for tile program
__host__ constexpr TileWindowWithStaticLengths(const BottomTensorView& bottom_tensor_view,
const WindowLengths&,
const BottomTensorIndex&)
: bottom_tensor_view_{bottom_tensor_view}, window_lengths_{}, window_origin_{}
{
}
__device__ constexpr TileWindowWithStaticLengths(const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin}
{
}
__host__ __device__ static constexpr index_t GetNumOfDimension() { return NDimBottomTensor; }
__host__ __device__ constexpr auto GetWindowLengths() const { return window_lengths_; }
__host__ __device__ constexpr auto GetBottomTensorView() const { return bottom_tensor_view_; }
__host__ __device__ constexpr auto GetWindowOrigin() const { return window_origin_; }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
//
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
};
template <typename TensorView_, typename WindowLengths_>
__host__ __device__ constexpr auto
make_tile_window(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const MultiIndex<TensorView_::GetNumOfDimension()>& origin)
{
static_assert(is_known_at_compile_time<WindowLengths_>::value,
"wrong! lengths should be static");
return TileWindowWithStaticLengths<remove_cvref_t<TensorView_>, remove_cvref_t<WindowLengths_>>{
tensor_view, window_lengths, origin};
}
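// Illustrative usage sketch (b_tensor_view and i_n are hypothetical): a plain
// window without a thread distribution, e.g. as a store target:
//
//   auto b_window = make_tile_window(b_tensor_view,
//                                    make_tuple(Number<32>{}, Number<128>{}),
//                                    {0, i_n});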
// FIXME: dummy host function for tile program
template <typename TensorView_, typename WindowLengths_>
__host__ void move_tile_window(
TileWindowWithStaticLengths<TensorView_, WindowLengths_>&,
const MultiIndex<
TileWindowWithStaticLengths<TensorView_, WindowLengths_>::GetNumOfDimension()>&)
{
}
template <typename TensorView_, typename WindowLengths_>
__device__ void move_tile_window(
TileWindowWithStaticLengths<TensorView_, WindowLengths_>& window,
const MultiIndex<TileWindowWithStaticLengths<TensorView_, WindowLengths_>::GetNumOfDimension()>&
step)
{
window.window_origin_ += step;
}
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tile_program/warp_tile/warp_gemm_impl.hpp"
#include "ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp"
#include "ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp"
namespace ck {
namespace tile_program {
namespace warp {
using WarpGemmMfmaF16F16F32M32N32K8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
using WarpGemmMfmaF16F16F32M16N16K16 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
using WarpGemmMfmaF16F16F32M32N32K16 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
2>>;
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16,
2>>;
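// Illustrative usage sketch: each alias is a functor over warp-distributed
// register tiles; with a, b, c of the matching AWarpTensor / BWarpTensor /
// CWarpTensor types (see WarpGemmImpl), a warp-scope c += a * b is:
//
//   using WarpGemm = WarpGemmMfmaF16F16F32M32N32K16;
//   WarpGemm{}(c, a, b); // two chained 32x32x8 MFMAs along K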
} // namespace warp
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp"
namespace ck {
namespace tile_program {
namespace warp {
template <typename WarpGemmAttributeMfmaImpl_>
struct WarpGemmAtrributeMfma
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename Impl::AVecType;
using BVecType = typename Impl::BVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
using AWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kAMLane>, Sequence<Impl::kABKLane, Impl::kABKPerLane>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<0, 0>>,
Sequence<2>,
Sequence<1>>;
using BWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kBNLane>, Sequence<Impl::kABKLane, Impl::kABKPerLane>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<0, 0>>,
Sequence<2>,
Sequence<1>>;
using CWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
Sequence<Impl::kCNLane>>,
Tuple<Sequence<1, 2>>,
Tuple<Sequence<1, 0>>,
Sequence<1, 1>,
Sequence<0, 2>>;
// c_vec += a_vec * b_vec
__device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
Impl{}(c_vec, a_vec, b_vec);
}
// c_vec = a_vec * b_vec
__device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return Impl{}(a_vec, b_vec);
}
};
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
struct WarpGemmAtrributeMfmaIterateK
{
static_assert(kKIter > 0, "wrong!");
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename vector_type_maker<typename Impl::AVecType, kKIter>::type::type;
using BVecType = typename vector_type_maker<typename Impl::BVecType, kKIter>::type::type;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
using AWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kAMLane>, Sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<0, 0>>,
Sequence<2>,
Sequence<1>>;
using BWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kBNLane>, Sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<0, 0>>,
Sequence<2>,
Sequence<1>>;
using CWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
Sequence<Impl::kCNLane>>,
Tuple<Sequence<1, 2>>,
Tuple<Sequence<1, 0>>,
Sequence<1, 1>,
Sequence<0, 2>>;
// c_vec += a_vec * b_vec
__device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
a_vector.template AsType<typename Impl::AVecType>()[iKIter],
b_vector.template AsType<typename Impl::BVecType>()[iKIter]);
});
}
// c_vec = a_vec * b_vec
__device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
constexpr auto I0 = Number<0>{};
// c = a * b
auto c_vec = Impl{}(a_vector.template AsType<typename Impl::AVecType>()[I0],
b_vector.template AsType<typename Impl::BVecType>()[I0]);
// c += a * b
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
a_vector.template AsType<typename Impl::AVecType>()[iKIter],
b_vector.template AsType<typename Impl::BVecType>()[iKIter]);
});
return c_vec;
}
};
template <typename WarpGemmAttributeMfmaImpl_>
struct WarpGemmAtrributeMfmaTransposedCDistribution
{
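// The transposed-C distribution is realized purely by swapping the roles of
// A and B: C = A * B implies C^T = B^T * A^T, so feeding the MFMA with
// swapped operands yields C with M/N (and their distributions) exchanged.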
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
using ADataType = typename Impl::BDataType;
using BDataType = typename Impl::ADataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename Impl::BVecType;
using BVecType = typename Impl::AVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
using AWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kBNLane>, Sequence<Impl::kABKLane, Impl::kABKPerLane>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<0, 0>>,
Sequence<2>,
Sequence<1>>;
using BWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kAMLane>, Sequence<Impl::kABKLane, Impl::kABKPerLane>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<0, 0>>,
Sequence<2>,
Sequence<1>>;
using CWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kCNLane>,
Sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<1, 0>>,
Sequence<2, 2>,
Sequence<0, 2>>;
// c_vec += a_vec * b_vec
__device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
// swap A and B
Impl{}(c_vec, b_vec, a_vec);
}
// c_vec = a_vec * b_vec
__device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
// swap A and B
return Impl{}(b_vec, a_vec);
}
};
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
// swap A and B
using ADataType = typename Impl::BDataType;
using BDataType = typename Impl::ADataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename vector_type_maker<typename Impl::BVecType, kKIter>::type::type;
using BVecType = typename vector_type_maker<typename Impl::AVecType, kKIter>::type::type;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
using AWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kBNLane>, Sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<0, 0>>,
Sequence<2>,
Sequence<1>>;
using BWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kAMLane>, Sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<0, 0>>,
Sequence<2>,
Sequence<1>>;
using CWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kCNLane>,
Sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<1, 0>>,
Sequence<2, 2>,
Sequence<0, 2>>;
// c_vec += a_vec * b_vec
__device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
// swap A and B, value and type
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
b_vector.template AsType<typename Impl::AVecType>()[iKIter],
a_vector.template AsType<typename Impl::BVecType>()[iKIter]);
});
}
// c_vec = a_vec * b_vec
__device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
constexpr auto I0 = Number<0>{};
// swap A and B, value and type
auto c_vec = Impl{}(b_vector.template AsType<typename Impl::AVecType>()[I0],
a_vector.template AsType<typename Impl::BVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
b_vector.template AsType<typename Impl::AVecType>()[iKIter],
a_vector.template AsType<typename Impl::BVecType>()[iKIter]);
});
return c_vec;
}
};
} // namespace warp
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
namespace tile_program {
namespace warp {
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
{
using ADataType = half_t;
using BDataType = half_t;
using CDataType = float;
using AVecType = typename vector_type<half_t, 4>::type;
using BVecType = typename vector_type<half_t, 4>::type;
using CVecType = typename vector_type<float, 16>::type;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
static constexpr index_t kK = 8;
static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 32;
static constexpr index_t kCM0PerLane = 4;
static constexpr index_t kCM1PerLane = 4;
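// lane-layout sanity (wave size 64):
//   A/B: kAMLane * kABKLane = 32 * 2 = 64 lanes, each lane holding
//        kM * kK / 64 = 4 = kABKPerLane half values (AVecType/BVecType)
//   C:   kCMLane * kCNLane = 2 * 32 = 64 lanes, each lane holding
//        kM * kN / 64 = 16 = kCM0PerLane * kCM1PerLane floats (CVecType)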
// c_vec += a_vec * b_vec
__device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
}
// c_vec = a_vec * b_vec
__device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
// FIXME: Is this correct?
return __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, CVecType{0.f}, 0, 0, 0);
}
};
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
{
using ADataType = half_t;
using BDataType = half_t;
using CDataType = float;
using AVecType = typename vector_type<half_t, 4>::type;
using BVecType = typename vector_type<half_t, 4>::type;
using CVecType = typename vector_type<float, 4>::type;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 16;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
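// lane-layout sanity (wave size 64):
//   A/B: kAMLane * kABKLane = 16 * 4 = 64 lanes, each lane holding
//        kM * kK / 64 = 4 = kABKPerLane half values (AVecType/BVecType)
//   C:   kCMLane * kCNLane = 4 * 16 = 64 lanes, each lane holding
//        kM * kN / 64 = 4 = kCM0PerLane * kCM1PerLane floats (CVecType)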
// c_vec += a_vec * b_vec
__device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
}
// c_vec = a_vec * b_vec
__device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
// FIXME: Is this correct?
return __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, CVecType{0.f}, 0, 0, 0);
}
};
} // namespace warp
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
namespace ck {
namespace tile_program {
namespace warp {
template <typename WarpGemmAttribute_>
struct WarpGemmImpl
{
using WarpGemmAttribute = remove_cvref_t<WarpGemmAttribute_>;
static constexpr index_t kM = WarpGemmAttribute::kM;
static constexpr index_t kN = WarpGemmAttribute::kN;
static constexpr index_t kK = WarpGemmAttribute::kK;
using ADataType = typename WarpGemmAttribute::ADataType;
using BDataType = typename WarpGemmAttribute::BDataType;
using CDataType = typename WarpGemmAttribute::CDataType;
using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
using AWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(AWarpDstrEncoding{}))>;
using BWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(BWarpDstrEncoding{}))>;
using CWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(CWarpDstrEncoding{}))>;
using AWarpTensor = StaticDistributedTensor<ADataType, AWarpDstr>;
using BWarpTensor = StaticDistributedTensor<BDataType, BWarpDstr>;
using CWarpTensor = StaticDistributedTensor<CDataType, CWarpDstr>;
__device__ void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const
{
using AVec = typename vector_type<ADataType, AWarpTensor::GetThreadBufferSize()>::type;
using BVec = typename vector_type<BDataType, BWarpTensor::GetThreadBufferSize()>::type;
using CVec = typename vector_type<CDataType, CWarpTensor::GetThreadBufferSize()>::type;
constexpr auto I0 = Number<0>{};
const auto a_vec = a.GetThreadBuffer().template GetAsType<AVec>(I0);
const auto b_vec = b.GetThreadBuffer().template GetAsType<BVec>(I0);
auto c_vec = c.GetThreadBuffer().template GetAsType<CVec>(I0);
// c_vec += a_vec * b_vec
WarpGemmAttribute{}(c_vec, a_vec, b_vec);
c.GetThreadBuffer().template SetAsType<CVec>(I0, c_vec);
}
__device__ auto operator()(const AWarpTensor& a, const BWarpTensor& b) const
{
CWarpTensor c;
using AVec = typename vector_type<ADataType, AWarpTensor::GetThreadBufferSize()>::type;
using BVec = typename vector_type<BDataType, BWarpTensor::GetThreadBufferSize()>::type;
using CVec = typename vector_type<CDataType, CWarpTensor::GetThreadBufferSize()>::type;
constexpr auto I0 = Number<0>{};
const auto a_vec = a.GetThreadBuffer().template GetAsType<AVec>(I0);
const auto b_vec = b.GetThreadBuffer().template GetAsType<BVec>(I0);
// c_vec = a_vec * b_vec
auto c_vec = WarpGemmAttribute{}(a_vec, b_vec);
c.GetThreadBuffer().template SetAsType<CVec>(I0, c_vec);
return c;
}
};
} // namespace warp
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
template <typename T>
__device__ T warp_shuffle_up(const T& var, uint32_t delta)
{
#if 0
return __shfl_up(var, delta);
#elif 1
const uint32_t wrap_around_delta = warpSize - delta;
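// ds_bpermute addresses lanes in bytes (hence the << 2 shifts); adding
// warpSize - delta instead of subtracting delta lets the lane index wrap
// modulo the wave size, so each lane reads from lane (lane_id - delta).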
return __builtin_amdgcn_ds_bpermute((__lane_id() << 2) + (wrap_around_delta << 2), var);
#endif
}
template <typename T>
__device__ T warp_shuffle_down(const T& var, uint32_t delta)
{
#if 0
return __shfl_down(var, delta);
#elif 1
return __builtin_amdgcn_ds_bpermute((__lane_id() << 2) + (delta << 2), var);
#endif
}
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP
#pragma once
#include <initializer_list>
#include "functional2.hpp"
#include "sequence.hpp"
@@ -17,25 +18,83 @@ struct Array
TData mData[NSize];
__host__ __device__ constexpr Array() : mData{} {}
__host__ __device__ constexpr Array(std::initializer_list<TData> ilist)
{
// note: ilist.size() is not a constant expression here, so an oversized
// initializer list cannot be rejected at compile time (the old static_assert
// measured a default-constructed, empty list and was always true); guard the
// copy loop at run time instead
index_t i = 0;
TData vlast = TData{};
for(const TData& val : ilist)
{
if(i >= NSize)
{
break;
}
mData[i] = val;
vlast = val;
++i;
}
// replicate the last given element into any remaining slots
for(; i < NSize; ++i)
{
mData[i] = vlast;
}
}
__host__ __device__ static constexpr index_t Size() { return NSize; }
template <index_t I>
__host__ __device__ constexpr const TData& At() const
{
return mData[I];
}
template <index_t I>
__host__ __device__ constexpr TData& At()
{
return mData[I];
}
__host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; }
__host__ __device__ constexpr TData& At(index_t i) { return mData[i]; }
__host__ __device__ constexpr const TData& operator[](index_t i) const { return mData[i]; }
__host__ __device__ constexpr TData& operator()(index_t i) { return mData[i]; }
template <typename T>
__host__ __device__ constexpr auto operator=(const T& a)
{
static_assert(T::Size() == Size(), "wrong! size not the same");
for(index_t i = 0; i < NSize; ++i)
{
mData[i] = a[i];
}
return *this;
}
__host__ __device__ static constexpr bool IsStatic() { return is_static_v<TData>; };
__host__ __device__ void Print() const
{
printf("Array{size: %d, data: [", NSize);
for(index_t i = 0; i < NSize; i++)
{
print(mData[i]);
if(i < NSize - 1)
{
printf(", ");
}
}
printf("]}");
}
};
// empty Array
@@ -45,22 +104,42 @@ struct Array<TData, 0>
using type = Array;
using data_type = TData;
__host__ __device__ constexpr Array() {}
__host__ __device__ static constexpr index_t Size() { return 0; }
__host__ __device__ static constexpr bool IsStatic() { return is_static_v<TData>; };
__host__ __device__ void Print() const { printf("Array{size: 0, data: []}"); }
};
template <typename T, typename... Xs>
__host__ __device__ constexpr auto make_array(Xs&&... xs)
{
using data_type = remove_cvref_t<T>;
return Array<data_type, sizeof...(Xs)>{std::forward<Xs>(xs)...};
}
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_array(F&& f, Number<N>)
{
using T = remove_cvref_t<decltype(f(Number<0>{}))>;
return unpack([&f](auto&&... is) { return Array<T, N>{f(is)...}; },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
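// e.g. generate_array([](auto i) { return 2 * i.value; }, Number<4>{})
//      yields Array<index_t, 4>{0, 2, 4, 6}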
// copy the first N elements of a statically sized container into Array<T, N>
template <typename T, index_t N, typename X>
__host__ __device__ constexpr auto to_array(const X& x)
{
static_assert(N <= X::Size(), "wrong! out of bound");
Array<T, N> arr;
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
return arr;
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_ARRAY_MULTI_INDEX_HPP
#define CK_ARRAY_MULTI_INDEX_HPP
#pragma once
#include "common_header.hpp"
@@ -77,4 +76,3 @@ __host__ __device__ constexpr auto operator*(const MultiIndex<NSize>& a, const T
}
} // namespace ck
#endif