Commit 5a9c4962 authored by Adam Osewski

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

parents 3970cf73 43879b89
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"
namespace ck_tile {
// placeholder type used when we want to opt out of a tile window parameter
template <typename WindowLengths_>
struct null_tile_window
{
using BottomTensorView = null_tensor_view;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using BottomTensorIndex = array<index_t, WindowLengths::size()>;
CK_TILE_DEVICE constexpr null_tile_window() = default;
CK_TILE_DEVICE constexpr null_tile_window(const WindowLengths& window_lengths)
: window_lengths_{window_lengths}
{
}
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return null_tensor_view{}; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
WindowLengths window_lengths_;
};
// utility to check if this is a Null Tile Window
namespace impl {
template <typename>
struct is_null_tile_window : public std::false_type
{
};
template <typename T>
struct is_null_tile_window<null_tile_window<T>> : public std::true_type
{
};
} // namespace impl
template <typename T>
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
{
return impl::is_null_tile_window<remove_cvref_t<T>>::value;
}
template <typename WindowLengths>
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths& window_lengths)
{
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}
template <typename WindowLengths, typename... Ts>
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
const WindowLengths& window_lengths,
const multi_index<WindowLengths::size()>& /*origin*/,
Ts&&...)
{
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}
template <typename WindowLengths>
CK_TILE_DEVICE void
move_tile_window(null_tile_window<WindowLengths>&,
const typename null_tile_window<WindowLengths>::BottomTensorIndex&)
{
}
} // namespace ck_tile
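A minimal usage sketch (illustrative only, not part of the header above): building the placeholder from static lengths and showing that generic code can detect it and "move" it without effect. The function name and the 32 x 8 lengths are arbitrary choices for this sketch.
// Illustrative sketch; uses only the ck_tile facilities defined above.
CK_TILE_DEVICE inline void null_tile_window_example()
{
    using namespace ck_tile;
    auto dummy = make_null_tile_window(make_tuple(number<32>{}, number<8>{}));
    // the placeholder is recognized at compile time
    static_assert(is_null_tile_window(decltype(dummy){}), "expected a null tile window");
    // moving a null window is a no-op, so generic code may call it unconditionally
    move_tile_window(dummy, typename decltype(dummy)::BottomTensorIndex{});
}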
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
namespace ck_tile {
namespace detail {
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor)
{
constexpr auto I0 = number<0>{};
using DataType = typename InTensor::DataType;
constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
// y_dim_out_to_in
constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
map<array<index_t, 2>, index_t> rh_major_minor_to_y_;
static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
rh_major_minor_to_y_({rh_major, rh_minor}) = i;
});
return rh_major_minor_to_y_;
};
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
constexpr auto y_dim_out_to_in = [&] {
map<index_t, index_t> y_dim_out_to_in_;
for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
{
y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
}
return y_dim_out_to_in_;
}();
//
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
// input and output vector dim in the order of input Y dims
constexpr index_t y_dim_vec_in = NDimY - 1;
constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
// vector lengths
constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
// # of vectors
constexpr index_t num_vec_in = vec_length_out;
constexpr index_t num_vec_out = vec_length_in;
using InVec = array<DataType, vec_length_in>;
using OutVec = array<DataType, vec_length_out>;
// using InVec = typename InVec::type;
// using OutVec = typename OutVec::type;
// SFC
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using SFC_Y = space_filling_curve<decltype(y_lengths),
typename arithmetic_sequence_gen<0, NDimY, 1>::type,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Y::get_num_of_access();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// in/out vectors to be transposed
thread_buffer<InVec, num_vec_in> in_vectors;
thread_buffer<OutVec, num_vec_out> out_vectors;
// loop over SFC and do transpose
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...] in the order of input tensor
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
// get input vectors
static_for<0, num_vec_in, 1>{}([&](auto i) {
constexpr auto idx_y_in = generate_array(
[&](auto ii) {
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
},
number<NDimY>{});
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
static_assert(in_offset % vec_length_in == 0);
in_vectors(i).template get_as<InVec>()(I0) =
in_tensor.get_thread_buffer()
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
});
// transpose
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
number<NDimY>{});
constexpr auto idx_y_out =
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
static_assert(out_offset % vec_length_out == 0);
out_tensor.get_thread_buffer().template set_as<OutVec>(
number<out_offset / vec_length_out>{},
out_vectors[i].template get_as<OutVec>()[I0]);
});
});
}
} // namespace detail
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
{
using InDataType = typename InTensor::DataType;
using OutDataType = typename OutTensor::DataType;
using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode;
using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;
// type convert
const auto in_tmp = tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
// shuffle
if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
InDstrEncode::NDimY == OutDstrEncode::NDimY)
{
detail::shuffle_tile_impl_in_thread(out, in_tmp);
}
else
{
// NOT implemented
}
}
} // namespace ck_tile
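A usage sketch (illustrative, not part of the header above): shuffle_tile is typically given a freshly created destination tensor whose distribution matches the source in its R/H lengths and P mappings but may order the Y dimensions differently. OutDataType, OutDstr and the function name are placeholder names for this sketch.
template <typename OutDataType, typename OutDstr, typename InTensor>
CK_TILE_DEVICE auto shuffle_tile_example(const InTensor& in)
{
    using namespace ck_tile;
    // destination tensor with the (compatible) target distribution
    auto out = make_static_distributed_tensor<OutDataType>(OutDstr{});
    // converts the element type and transposes the per-thread vectors
    shuffle_tile(out, in);
    return out;
}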
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
index_t... SliceBegins,
index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile,
sequence<SliceBegins...> slice_begins,
sequence<SliceEnds...> slice_ends)
{
using TileWindow = tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>;
// NOTE: This API will override the origin of the tile window!
static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds));
static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension());
constexpr auto slice_lengths = slice_ends - slice_begins;
return make_tile_window(tile.get_bottom_tensor_view(),
sequence_to_tuple_of_number(slice_lengths),
to_multi_index(slice_begins));
}
template <typename DataType_,
typename StaticTileDistribution_,
index_t... SliceBegins,
index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const static_distributed_tensor<DataType_, StaticTileDistribution_>& tile,
sequence<SliceBegins...> slice_begins,
sequence<SliceEnds...> slice_ends)
{
using DataType = remove_cvref_t<DataType_>;
using Distribution = remove_cvref_t<StaticTileDistribution_>;
constexpr auto sliced_dstr_yidx_ylen =
detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);
constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);
sliced_tensor.get_thread_buffer() =
tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths);
return sliced_tensor;
}
template <typename DstDataType_,
typename DstStaticTileDistribution_,
typename SrcDataType_,
typename SrcStaticTileDistribution_,
index_t... SliceBegins,
index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution_>& dst_tile,
const static_distributed_tensor<SrcDataType_, SrcStaticTileDistribution_>& src_tile,
sequence<SliceBegins...> slice_begins,
sequence<SliceEnds...> slice_ends)
{
using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;
constexpr auto sliced_dstr_yidx_ylen =
detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
dst_tile.set_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
}
} // namespace ck_tile
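A usage sketch (illustrative, not part of the header above): slicing a 2D statically sized tile window down to a [0, 32) x [0, 64) sub-window. The extents are arbitrary example values; as noted above, the returned window is re-anchored at the slice begin.
template <typename TileWindow>
CK_TILE_DEVICE auto slice_window_example(const TileWindow& window)
{
    using namespace ck_tile;
    // begin / end per dimension; the resulting window lengths become {32, 64}
    return get_slice_tile(window, sequence<0, 0>{}, sequence<32, 64>{});
}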
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
namespace ck_tile {
template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor
{
using DataType = remove_cvref_t<DataType_>;
using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;
static_assert(StaticTileDistribution::is_static(),
"wrong! StaticTileDistribution should be known at compile tile");
using ThreadTensorDesc =
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
{
return StaticTileDistribution::get_num_of_dimension_x();
}
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
{
return StaticTileDistribution::get_lengths();
}
CK_TILE_HOST_DEVICE static constexpr auto get_tile_distribution()
{
return StaticTileDistribution{};
}
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
{
return StaticTileDistribution::get_distributed_spans();
}
CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); }
CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; }
CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; }
CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
{
return kThreadElementSpaceSize;
}
template <index_t... YSliceOrigins, index_t... YSliceLengths>
CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence<YSliceOrigins...>,
sequence<YSliceLengths...>) const
{
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
"wrong!");
constexpr auto sliced_thread_tensor_desc =
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
sliced_thread_data;
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
sliced_thread_data(number<sliced_thread_tensor_desc.calculate_offset(idx)>{}) =
thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}];
});
return sliced_thread_data;
}
template <index_t... YSliceOrigins, index_t... YSliceLengths, typename SlicedThreadData>
CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence<YSliceOrigins...>,
sequence<YSliceLengths...>,
const SlicedThreadData& sliced_thread_data)
{
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
"wrong!");
constexpr auto sliced_thread_tensor_desc =
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}) =
sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx)>{}];
});
}
template <typename TileDistributedIndices>
CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const
{
static_assert(is_static_v<TileDistributedIndices>,
"wrong! Tile Distributed Indices should be static");
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
TileDistributedIndices{});
return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx)>{}];
}
template <typename TileDistributedIndices>
CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices)
{
static_assert(is_static_v<TileDistributedIndices>,
"wrong! Tile Distributed Indices should be static");
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
TileDistributedIndices{});
return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx)>{});
}
//
thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_;
};
template <typename DataType, typename StaticTileDistribution>
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&)
{
return static_distributed_tensor<remove_cvref_t<DataType>,
remove_cvref_t<StaticTileDistribution>>{};
}
template <typename DataType, typename StaticTileDistribution, typename ThreadBuffer>
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&,
ThreadBuffer&& thread_buffer_)
{
return static_distributed_tensor<remove_cvref_t<DataType>,
remove_cvref_t<StaticTileDistribution>>{thread_buffer_};
}
// get X indices from tuple of tile_distributed_index<>
template <typename StaticTileDistribution, typename DistributedIndices>
CK_TILE_HOST_DEVICE constexpr auto
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
DistributedIndices distributed_indices)
{
const auto partition_index = detail::get_partition_index(tile_distribution);
constexpr auto y_indices =
tile_distribution.get_y_indices_from_distributed_indices(distributed_indices);
const auto x_coord = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(partition_index, to_array<ck_tile::index_t, y_indices.size()>(y_indices)));
return x_coord.get_bottom_index();
}
template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
CK_TILE_HOST_DEVICE void
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
DataType value,
XIndicesPredicate predicate)
{
constexpr auto out_spans =
static_distributed_tensor<DataType, StaticTileDistribution>::get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
constexpr auto distributed_indices = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{},
distributed_indices);
if(predicate(x_indices))
{
out_tensor(distributed_indices) = value;
}
});
});
}
} // namespace ck_tile
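A usage sketch (illustrative, not part of the header above): creating a distributed tensor for a 2D distribution, clearing it, then using set_tile_if to write a value only into rows whose global row index is below a runtime bound. DataType, Dstr, valid_m and the function name are placeholders for this sketch.
template <typename DataType, typename Dstr>
CK_TILE_DEVICE auto masked_fill_example(ck_tile::index_t valid_m)
{
    using namespace ck_tile;
    auto t = make_static_distributed_tensor<DataType>(Dstr{});
    t.initialize(DataType{0});
    // the predicate receives the X (global tile) indices of each element
    set_tile_if(t, DataType{1}, [&](auto x_indices) {
        return x_indices.at(number<0>{}) < valid_m;
    });
    return t;
}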
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
constexpr auto tile_dstr = TileDstr{};
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
tile_window_tmp.get_window_lengths(),
tile_window_tmp.get_window_origin(),
tile_dstr);
tile_window.store(dstr_tensor);
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
constexpr auto tile_dstr = TileDstr{};
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
tile_window_tmp.get_window_lengths(),
tile_window_tmp.get_window_origin(),
tile_dstr);
tile_window.store_raw(dstr_tensor);
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
tile_window.store(dstr_tensor);
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
tile_window.store_raw(dstr_tensor);
}
} // namespace ck_tile
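A usage sketch (illustrative, not part of the header above): storing a computed distributed tensor through a statically sized tile window. c_view, the 64 x 64 window lengths and the {0, 0} origin are placeholder example values, and the distributed tensor is assumed to carry a matching 64 x 64 distribution.
template <typename CTensorView, typename CDistributedTensor>
CK_TILE_DEVICE void store_example(CTensorView& c_view, const CDistributedTensor& c_tile)
{
    using namespace ck_tile;
    auto c_window = make_tile_window(c_view,
                                     make_tuple(number<64>{}, number<64>{}),
                                     {0, 0});
    // re-creates the window with the tensor's distribution, then writes it out
    store_tile(c_window, c_tile);
}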
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// sweep over a span of a distributed tile and apply the lambda function F
template <typename TileDistributedSpan_, // tile_distributed_span<...>
typename F // signature: F(tile_distributed_index<...>)
>
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
{
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);
f(dstr_idx);
});
}
} // namespace ck_tile
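A usage sketch (illustrative, not part of the header above): the same nested-sweep pattern used by set_tile_if earlier, here scaling every element of a 2D distributed tensor in place. AccTensor and the function name are placeholders, and DataType is assumed to support arithmetic.
template <typename AccTensor>
CK_TILE_DEVICE void scale_tile_example(AccTensor& acc, typename AccTensor::DataType scale)
{
    using namespace ck_tile;
    constexpr auto spans = AccTensor::get_distributed_spans();
    sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
        sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
            constexpr auto idx = make_tuple(idx0, idx1);
            // read through operator[], write through operator()
            acc(idx) = acc[idx] * scale;
        });
    });
}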
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
namespace ck_tile {
// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// BottomDimensionHiddenIds : Sequence<...>
// TopDimensionHiddenIds : Sequence<...>
template <typename Transforms,
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename BottomDimensionHiddenIds,
typename TopDimensionHiddenIds>
struct tensor_adaptor
{
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_transform()
{
return Transforms::size();
}
CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const { return transforms_; }
CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
{
return LowerDimensionHiddenIdss{};
}
CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
{
return UpperDimensionHiddenIdss{};
}
CK_TILE_HOST_DEVICE static constexpr auto get_bottom_dimension_hidden_ids()
{
return BottomDimensionHiddenIds{};
}
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
{
return TopDimensionHiddenIds{};
}
CK_TILE_HOST_DEVICE static constexpr auto initialize_element_size(const Transforms& transforms)
{
const auto lengths = generate_tuple(
[&](auto idim_top) {
constexpr index_t idim_hidden = TopDimensionHiddenIds::at(idim_top);
constexpr auto tmp = get_transform_and_its_upper_dimension(number<idim_hidden>{});
constexpr index_t itran = tmp[number<0>{}];
constexpr index_t idim_up = tmp[number<1>{}];
constexpr bool found = tmp[number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
const auto length =
transforms[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
return length;
},
number<ndim_top_>{});
// TODO: make container_reduce support tuple of number and index_t
return container_reduce(lengths, multiplies{}, number<1>{});
}
template <index_t IDimHidden>
CK_TILE_HOST_DEVICE static constexpr auto
get_transform_and_its_upper_dimension(number<IDimHidden>)
{
// FIXME: length of bottom dimension is not known, since info about lower dim lengths is not
// saved in the transformation
static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented");
index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;
static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
static_for<0, up_dim_ids.size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == IDimHidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
return make_tuple(itran_found, idim_up_found, found);
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_bottom_dimension()
{
return BottomDimensionHiddenIds::size();
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_top_dimension()
{
return TopDimensionHiddenIds::size();
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
LowerDimensionHiddenIdss{});
constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
UpperDimensionHiddenIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
less<index_t>,
equal<index_t>>::type;
return unique_sort_all_dim_ids::size();
}
constexpr static index_t ntransform_ = get_num_of_transform();
constexpr static index_t ndim_hidden_ = get_num_of_hidden_dimension();
constexpr static index_t ndim_bottom_ = get_num_of_bottom_dimension();
constexpr static index_t ndim_top_ = get_num_of_top_dimension();
using HiddenIndex = multi_index<ndim_hidden_>;
using BottomIndex = multi_index<ndim_bottom_>;
using TopIndex = multi_index<ndim_top_>;
// may be index_t or number<>
using ElementSize = remove_cv_t<decltype(initialize_element_size(Transforms{}))>;
public:
CK_TILE_HOST_DEVICE constexpr tensor_adaptor() = default;
CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms& transforms)
: transforms_{transforms}, element_size_{initialize_element_size(transforms)}
{
static_assert(Transforms::size() == ntransform_ &&
LowerDimensionHiddenIdss::size() == ntransform_ &&
UpperDimensionHiddenIdss::size() == ntransform_,
"wrong! inconsistent # of transformations");
// TODO check dependency of dimensions is valid
}
CK_TILE_HOST_DEVICE constexpr auto get_element_size() const { return element_size_; }
// FIXME: this logic is wrong when getting bottom dimension lengths
template <index_t IDimHidden>
CK_TILE_HOST_DEVICE constexpr auto get_hidden_dimension_length(number<IDimHidden>) const
{
static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range");
constexpr auto tmp = get_transform_and_its_upper_dimension(number<IDimHidden>{});
constexpr index_t itran = tmp[number<0>{}];
constexpr index_t idim_up = tmp[number<1>{}];
constexpr bool found = tmp[number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
return transforms_[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
}
template <index_t IDimTop>
CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_length(number<IDimTop> idim_top) const
{
return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_top));
}
#if 0
// FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths
template <index_t IDimBottom>
CK_TILE_HOST_DEVICE constexpr index_t
get_bottom_dimension_length(number<IDimBottom> idim_bottom) const
{
return get_hidden_dimension_length(BottomDimensionHiddenIds::at(idim_bottom));
}
#endif
CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_lengths() const
{
return generate_tuple([&](auto i) { return get_top_dimension_length(i); },
number<ndim_top_>{});
}
#if 0
// FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths
CK_TILE_HOST_DEVICE constexpr auto get_bottom_dimension_lengths() const
{
return generate_tuple([&](auto i) { return get_bottom_dimension_length(i); },
number<ndim_bottom_>{});
}
#endif
template <typename TopIdx>
CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx& idx_top) const
{
static_assert(TopIdx::size() == TopDimensionHiddenIds::size(),
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = get_num_of_transform();
constexpr index_t ndim_hidden = get_num_of_hidden_dimension();
multi_index<ndim_hidden> idx_hidden;
// initialize the top index
set_container_subset(idx_hidden, get_top_dimension_hidden_ids(), idx_top);
// calculate hidden index
static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
auto itran = itran_p1 - number<1>{};
const auto& tran = get_transforms().at(itran);
constexpr auto dims_low = get_lower_dimension_hidden_idss().at(itran);
constexpr auto dims_up = get_upper_dimension_hidden_idss().at(itran);
const auto idx_up = get_container_subset(idx_hidden, dims_up);
multi_index<dims_low.size()> idx_low;
tran.calculate_lower_index(idx_low, idx_up);
set_container_subset(idx_hidden, dims_low, idx_low);
});
return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
}
CK_TILE_HOST_DEVICE static constexpr bool is_static()
{
bool is_known = true;
static_for<0, Transforms::size(), 1>{}([&](auto i) {
is_known &= remove_cvref_t<decltype(Transforms{}[i])>::is_known_at_compile_time();
});
return is_known && ck_tile::is_known_at_compile_time<ElementSize>::value;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
{
auto vector_lengths = guaranteed_vector_lengths;
auto vector_strides = guaranteed_vector_strides;
static_for<0, get_num_of_transform(), 1>{}([&](auto itran) {
constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
constexpr auto up_dims = get_upper_dimension_hidden_idss().at(itran);
const auto up_guaranteed_vector_lengths =
get_container_subset(guaranteed_vector_lengths, up_dims);
const auto up_guaranteed_vector_strides =
get_container_subset(guaranteed_vector_strides, up_dims);
// only need type of transform
auto [up_vector_lengths, up_vector_strides] =
Transforms{}.at(itran).calculate_upper_dimension_safe_vector_length_strides(
get_container_subset(vector_lengths, low_dims),
get_container_subset(vector_strides, low_dims));
if constexpr(up_dims.size() > 0)
{
for(index_t i = 0; i < up_dims.size(); ++i)
{
up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1)
? up_guaranteed_vector_lengths[i]
: up_vector_lengths[i];
up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1)
? up_guaranteed_vector_strides[i]
: up_vector_strides[i];
}
}
set_container_subset(vector_lengths, up_dims, up_vector_lengths);
set_container_subset(vector_strides, up_dims, up_vector_strides);
});
constexpr auto top_dims = TopDimensionHiddenIds{};
return make_tuple(get_container_subset(vector_lengths, top_dims),
get_container_subset(vector_strides, top_dims));
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_adaptor{");
//
printf("transforms: ");
print(transforms_);
printf(", ");
//
printf("LowerDimensionHiddenIds: ");
print(LowerDimensionHiddenIdss{});
printf(", ");
//
printf("UpperDimensionHiddenIds: ");
print(UpperDimensionHiddenIdss{});
printf(", ");
//
printf("BottomDimensionHiddenIds: ");
print(BottomDimensionHiddenIds{});
printf(", ");
//
printf("TopDimensionHiddenIds: ");
print(TopDimensionHiddenIds{});
printf("}");
}
private:
Transforms transforms_;
ElementSize element_size_;
};
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
LowerDimensionOldTopIdss,
UpperDimensionNewTopIdss)
{
constexpr index_t ntransform = Transforms::size();
static_assert(LowerDimensionOldTopIdss::size() == ntransform &&
UpperDimensionNewTopIdss::size() == ntransform,
"wrong!");
// sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
constexpr auto all_low_dim_old_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});
constexpr auto all_up_dim_new_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
"wrong!");
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();
// low_dim_hidden_idss
constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
// up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
constexpr auto up_dim_hidden_idss = generate_tuple(
[](auto itran) { return UpperDimensionNewTopIdss{}[itran] + number<ndim_old_top>{}; },
number<ntransform>{});
// bottom_dim_hidden_ids
constexpr auto bottom_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};
// top_dim_hidden_ids
constexpr auto top_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + number<ndim_old_top>{};
return tensor_adaptor<remove_cvref_t<Transforms>,
remove_cvref_t<decltype(low_dim_hidden_idss)>,
remove_cvref_t<decltype(up_dim_hidden_idss)>,
remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
remove_cvref_t<decltype(top_dim_hidden_ids)>>{transforms};
}
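// A usage sketch (illustrative, not part of the original header): a single merge
// transform that folds a 4 x 8 bottom index space into one top dimension of length
// 32. The lengths and the function name are arbitrary choices for this sketch.
CK_TILE_HOST_DEVICE constexpr auto make_merge_adaptor_example()
{
    return make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(number<4>{}, number<8>{}))),
        make_tuple(sequence<0, 1>{}), // lower (old top) dimension ids
        make_tuple(sequence<0>{}));   // upper (new top) dimension ids
}
// On the adaptor above, calculate_bottom_index(make_multi_index(10)) yields {1, 2},
// since 10 == 1 * 8 + 2.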
// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
// doesn't have a default constructor, and it is defined outside the scope where it
// is used (transform_tensor_adaptor) because a template cannot be defined inside a
// function template.
template <typename NewTransforms>
struct lambda_get_up_dim_num
{
template <typename I>
CK_TILE_HOST_DEVICE constexpr auto operator()(I) const
{
using Tran = remove_reference_t<decltype(NewTransforms{}.at(I{}))>;
return number<Tran::get_num_of_upper_dimension()>{};
}
};
template <typename OldTensorAdaptor,
typename NewTransforms,
typename NewLowerDimensionOldTopIdss,
typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
const NewTransforms& new_transforms,
NewLowerDimensionOldTopIdss,
NewUpperDimensionNewTopIdss)
{
// sanity check
{
static_assert(NewTransforms::size() == NewLowerDimensionOldTopIdss::size() &&
NewTransforms::size() == NewUpperDimensionNewTopIdss::size(),
"wrong! inconsitent number of transform");
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldTopIdss{});
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
"wrong!");
}
// lower dimension's hidden idss
// convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of
// sequences)
constexpr auto low_dim_hidden_idss = transform_tuples(
// convert lower dimension top ids (a sequence) to hidden ids (a sequence)
[](auto low_dim_top_ids) constexpr {
return transform_sequences(
// convert lower dimension top id to hidden id
[](auto low_dim_top_id) constexpr {
return OldTensorAdaptor::get_top_dimension_hidden_ids()[low_dim_top_id];
},
low_dim_top_ids);
},
NewLowerDimensionOldTopIdss{});
constexpr index_t num_new_transform = NewTransforms::size();
// upper dimension's hidden idss
constexpr index_t old_hidden_dim_number = OldTensorAdaptor::get_num_of_hidden_dimension();
constexpr auto up_dim_numbers =
generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, number<num_new_transform>{});
constexpr auto up_dim_numbers_scan = merge_sequences(
sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));
constexpr auto up_dim_hidden_idss = generate_tuple(
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
1>::type{};
},
number<num_new_transform>{});
// new top dimension's hidden ids
constexpr auto unordered_new_top_dim_hidden_ids = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
constexpr auto new_top_dim_unordered2ordered = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});
constexpr auto new_top_dim_hidden_ids =
unordered_new_top_dim_hidden_ids.reorder_old_to_new(new_top_dim_unordered2ordered);
// put everything together
const auto all_transforms =
container_concat(old_tensor_adaptor.get_transforms(), new_transforms);
constexpr auto all_low_dim_hidden_idss =
container_concat(OldTensorAdaptor::get_lower_dimension_hidden_idss(), low_dim_hidden_idss);
constexpr auto all_up_dim_hidden_idss =
container_concat(OldTensorAdaptor::get_upper_dimension_hidden_idss(), up_dim_hidden_idss);
return tensor_adaptor<
remove_cvref_t<decltype(all_transforms)>,
remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
remove_cvref_t<decltype(OldTensorAdaptor::get_bottom_dimension_hidden_ids())>,
remove_cvref_t<decltype(new_top_dim_hidden_ids)>>{all_transforms};
}
template <typename TensorAdaptor0, typename TensorAdaptor1>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
const TensorAdaptor1& adaptor1)
{
static_assert(TensorAdaptor0::get_num_of_top_dimension() ==
TensorAdaptor1::get_num_of_bottom_dimension(),
"wrong!");
// all_transforms = transform0 + transform1
const auto all_transforms =
container_concat(adaptor0.get_transforms(), adaptor1.get_transforms());
// shift
constexpr index_t adaptor0_max_hidden_id = [&]() {
index_t adaptor0_max_hidden_id_ = numeric<index_t>::min();
static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low =
TensorAdaptor0{}.get_transforms()[itran].get_num_of_lower_dimension();
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
adaptor0_max_hidden_id_ =
max(adaptor0_max_hidden_id_,
TensorAdaptor0::get_lower_dimension_hidden_idss()[itran][idim_low].value);
});
constexpr index_t ndim_up =
TensorAdaptor0{}.get_transforms()[itran].get_num_of_upper_dimension();
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor0_max_hidden_id_ =
max(adaptor0_max_hidden_id_,
TensorAdaptor0::get_upper_dimension_hidden_idss()[itran][idim_up].value);
});
});
return adaptor0_max_hidden_id_;
}();
constexpr index_t adaptor1_min_hidden_id = [&]() {
index_t adaptor1_min_hidden_id_ = numeric<index_t>::max();
static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low =
TensorAdaptor1{}.get_transforms()[itran].get_num_of_lower_dimension();
// get the min of all lower dimensions, but not the bottom dimensions (because their ids will
// be matched with top ids from adaptor0)
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
constexpr index_t low_dim_hidden_id =
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran][idim_low].value;
bool is_bottom_dim = false;
static_for<0, TensorAdaptor1::get_num_of_bottom_dimension(), 1>{}([&](auto i) {
if constexpr(low_dim_hidden_id ==
TensorAdaptor1::get_bottom_dimension_hidden_ids()[i])
{
is_bottom_dim = true;
}
});
if(!is_bottom_dim)
{
adaptor1_min_hidden_id_ = min(adaptor1_min_hidden_id_, low_dim_hidden_id);
}
});
constexpr index_t ndim_up =
TensorAdaptor1{}.get_transforms()[itran].get_num_of_upper_dimension();
// get the min of all upper dimensions
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor1_min_hidden_id_ =
min(adaptor1_min_hidden_id_,
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran][idim_up].value);
});
});
return adaptor1_min_hidden_id_;
}();
constexpr index_t adaptor1_hidden_id_shift =
adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;
constexpr index_t ndim_bottom_1 = TensorAdaptor1::get_num_of_bottom_dimension();
// all_low_dim_hidden_idss =
// low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hidden_idss_1))
constexpr auto low_dim_hidden_idss_1 = generate_tuple(
// generate sequence of ids for a transform
[&](auto itran) {
constexpr auto ndim_low_1 =
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran].size();
constexpr auto low_dim_hidden_ids_1 =
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];
// sequence in, sequence out
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
{
auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
// shift hidden id so every dim id is unique
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
});
// match hidden id
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
// if this low dim is bottom dim, then do id matching
if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
TensorAdaptor1::get_bottom_dimension_hidden_ids()
[idim_bottom_1])
{
low_dim_hidden_ids_1_mod_(idim_low_1) =
TensorAdaptor0::get_top_dimension_hidden_ids()[idim_bottom_1];
}
});
});
return low_dim_hidden_ids_1_mod_;
}
();
return generate_sequence_v2(
[&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
number<ndim_low_1>{});
},
number<TensorAdaptor1::get_num_of_transform()>{});
constexpr auto all_low_dim_hidden_idss =
container_concat(TensorAdaptor0::get_lower_dimension_hidden_idss(), low_dim_hidden_idss_1);
// all_up_dim_hidden_idss =
// up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hidden_idss_1)
constexpr auto up_dim_hidden_idss_1 = generate_tuple(
// generate sequence of ids for a transform
[&](auto itran) {
constexpr auto ndim_up_1 =
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran].size();
constexpr auto up_dim_hidden_ids_1 =
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];
// sequence in, constexpr tuple out
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
{
auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
// shift hidden id
static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
});
return up_dim_hidden_ids_1_mod_;
}
();
// constexpr tuple to sequence
return generate_sequence_v2(
[&](auto i) constexpr { return number<up_dim_hidden_ids_1_mod[i]>{}; },
number<ndim_up_1>{});
},
number<TensorAdaptor1::get_num_of_transform()>{});
constexpr auto all_up_dim_hidden_idss =
container_concat(TensorAdaptor0::get_upper_dimension_hidden_idss(), up_dim_hidden_idss_1);
// bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::get_bottom_dimension_hidden_ids();
// top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
constexpr auto top_dim_hidden_ids =
TensorAdaptor1::get_top_dimension_hidden_ids() + number<adaptor1_hidden_id_shift>{};
// put everything together
return tensor_adaptor<remove_cvref_t<decltype(all_transforms)>,
remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
template <typename X,
typename... Xs,
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
}
} // namespace ck_tile
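A usage sketch (illustrative, not part of the headers above): chaining two single-stage adaptors, an unmerge of a length-32 dimension into 4 x 8 followed by pass-throughs of both pieces. The top dimensions of the first adaptor pair up with the bottom dimensions of the second; all lengths and names are arbitrary choices for this sketch.
CK_TILE_HOST_DEVICE constexpr auto make_chained_adaptor_example()
{
    using namespace ck_tile;
    constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
        make_tuple(make_unmerge_transform(make_tuple(number<4>{}, number<8>{}))),
        make_tuple(sequence<0>{}),
        make_tuple(sequence<0, 1>{}));
    constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
        make_tuple(make_pass_through_transform(number<4>{}),
                   make_pass_through_transform(number<8>{})),
        make_tuple(sequence<0>{}, sequence<1>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));
    // bottom of the result is adaptor0's bottom (one length-32 dim),
    // top of the result is adaptor1's top (two dims of lengths 4 and 8)
    return chain_tensor_adaptors(adaptor0, adaptor1);
}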
// Macro function
// construct constexpr tensor_adaptor from constexpr encoding
// encoded_tensor_adaptor is a Tuple of the following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of the following:
// 1.1 name (coord_transform_enum)
// 1.2 meta data for constructor of the transform
// 1.3 num of lower dimension (index_t)
// 1.4 lower dimension Ids (array of fixed size)
// 1.5 num of up dimension (index_t)
// 1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimension (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimension (index_t)
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
[encoded_tensor_adaptor]() { \
using namespace ck_tile; \
\
constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \
constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \
constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
constexpr auto meta_data = encoded_transforms[i].template at<1>(); \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
\
static_assert(name == coord_transform_enum::pass_through || \
name == coord_transform_enum::pad || \
name == coord_transform_enum::embed || \
name == coord_transform_enum::merge || \
name == coord_transform_enum::unmerge || \
name == coord_transform_enum::replicate, \
""); \
\
if constexpr(name == coord_transform_enum::pass_through) \
{ \
index_t pos = 0; \
auto low_len = meta_data.template pop<index_t>(pos); \
\
return make_pass_through_transform(low_len); \
} \
else if constexpr(name == coord_transform_enum::pad) \
{ \
index_t pos = 0; \
auto low_len = meta_data.template pop<index_t>(pos); \
auto left_pad = meta_data.template pop<index_t>(pos); \
auto right_pad = meta_data.template pop<index_t>(pos); \
\
return make_pad_transform(low_len, left_pad, right_pad); \
} \
else if constexpr(name == coord_transform_enum::embed) \
{ \
index_t pos = 0; \
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
auto coefficients = \
meta_data.template pop<array<index_t, num_up_dim>>(pos); \
\
return make_embed_transform(up_lens, coefficients); \
} \
else if constexpr(name == coord_transform_enum::merge) \
{ \
index_t pos = 0; \
auto low_lens = meta_data.template pop<array<index_t, num_low_dim>>(pos); \
\
return make_merge_transform(low_lens); \
} \
else if constexpr(name == coord_transform_enum::unmerge) \
{ \
index_t pos = 0; \
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
\
return make_unmerge_transform(up_lens); \
} \
else if constexpr(name == coord_transform_enum::replicate) \
{ \
index_t pos = 0; \
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
\
return make_replicate_transform(up_lens); \
} \
}, \
number<num_transform>{}); \
}(); \
\
constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto low_dims = encoded_transforms[i].template at<3>(); \
\
return TO_SEQUENCE(low_dims, num_low_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
constexpr auto up_dims = encoded_transforms[i].template at<5>(); \
\
return TO_SEQUENCE(up_dims, num_up_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
\
return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
remove_cvref_t<decltype(low_dim_idss)>, \
remove_cvref_t<decltype(up_dim_idss)>, \
remove_cvref_t<decltype(bottom_dim_ids)>, \
remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
}()
// Macro function
// construct static tensor_adaptor from constexpr encoding
// encoded_tensor_adaptor is a Tuple of the following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of the following:
// 1.1 name (coord_transform_enum)
// 1.2 meta data for constructor of the transform
// 1.3 num of lower dimension (index_t)
// 1.4 lower dimension Ids (array of fixed size)
// 1.5 num of up dimension (index_t)
// 1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimension (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimension (index_t)
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
[encoded_tensor_adaptor]() { \
using namespace ck_tile; \
\
constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \
constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \
constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
constexpr auto meta_data = encoded_transforms[i].template at<1>(); \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
\
static_assert(name == coord_transform_enum::pass_through || \
name == coord_transform_enum::pad || \
name == coord_transform_enum::embed || \
name == coord_transform_enum::merge || \
name == coord_transform_enum::unmerge || \
name == coord_transform_enum::replicate, \
""); \
\
if constexpr(name == coord_transform_enum::pass_through) \
{ \
constexpr index_t low_len = meta_data.template get<index_t>(0); \
\
return make_pass_through_transform(number<low_len>{}); \
} \
else if constexpr(name == coord_transform_enum::pad) \
{ \
constexpr index_t low_len = meta_data.template get<index_t>(0); \
\
constexpr index_t left_pad = \
meta_data.template get<index_t>(sizeof(low_len)); \
\
constexpr index_t right_pad = \
meta_data.template get<index_t>(sizeof(low_len) + sizeof(left_pad)); \
\
return make_pad_transform( \
number<low_len>{}, number<left_pad>{}, number<right_pad>{}); \
} \
else if constexpr(name == coord_transform_enum::embed) \
{ \
constexpr auto up_lens = \
meta_data.template get<array<index_t, num_up_dim>>(0); \
\
constexpr auto coefficients = \
meta_data.template get<array<index_t, num_up_dim>>(sizeof(up_lens)); \
\
return make_embed_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim), \
TO_TUPLE_OF_NUMBER(coefficients, num_up_dim)); \
} \
else if constexpr(name == coord_transform_enum::merge) \
{ \
constexpr auto low_lens = \
meta_data.template get<array<index_t, num_low_dim>>(0); \
\
return make_merge_transform(TO_TUPLE_OF_NUMBER(low_lens, num_low_dim)); \
} \
else if constexpr(name == coord_transform_enum::unmerge) \
{ \
constexpr auto up_lens = \
meta_data.template get<array<index_t, num_up_dim>>(0); \
\
return make_unmerge_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
} \
else if constexpr(name == coord_transform_enum::replicate) \
{ \
constexpr auto up_lens = \
meta_data.template get<array<index_t, num_up_dim>>(0); \
\
return make_replicate_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
} \
}, \
number<num_transform>{}); \
}(); \
\
constexpr auto low_dim_idss = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto low_dims = encoded_transforms[i].template at<3>(); \
\
return TO_SEQUENCE(low_dims, num_low_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto up_dim_idss = [&encoded_transforms] { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
constexpr auto up_dims = encoded_transforms[i].template at<5>(); \
\
return TO_SEQUENCE(up_dims, num_up_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
\
return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
remove_cvref_t<decltype(low_dim_idss)>, \
remove_cvref_t<decltype(up_dim_idss)>, \
remove_cvref_t<decltype(bottom_dim_ids)>, \
remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
}()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
struct tensor_adaptor_coordinate
{
static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::size();
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();
using HiddenIndex = multi_index<NDimHidden>;
using BottomIndex = multi_index<ndim_bottom_>;
using TopIndex = multi_index<ndim_top_>;
public:
CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate() = default;
CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate(const HiddenIndex& idx_hidden)
: idx_hidden_{idx_hidden}
{
}
CK_TILE_HOST_DEVICE constexpr auto get_top_index() const
{
return get_container_subset(idx_hidden_, TopDimensionHiddenIds{});
}
CK_TILE_HOST_DEVICE constexpr auto get_bottom_index() const
{
return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{});
}
CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const { return idx_hidden_; }
CK_TILE_HOST_DEVICE constexpr auto& get_hidden_index() { return idx_hidden_; }
//
HiddenIndex idx_hidden_;
};
template <typename Adaptor, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor,
const TopIndex& idx_top)
{
static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = Adaptor::get_num_of_transform();
constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
constexpr auto top_dim_ids = Adaptor::get_top_dimension_hidden_ids();
multi_index<ndim_hidden> idx_hidden;
// initialize visible index
set_container_subset(idx_hidden, top_dim_ids, idx_top);
// calculate hidden index
static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
auto itran = itran_p1 - number<1>{};
const auto& tran = adaptor.get_transforms().at(itran);
constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
const auto idx_up = get_container_subset(idx_hidden, dims_up);
multi_index<dims_low.size()> idx_low;
tran.calculate_lower_index(idx_low, idx_up);
set_container_subset(idx_hidden, dims_low, idx_low);
});
return tensor_adaptor_coordinate<ndim_hidden,
remove_cvref_t<decltype(bottom_dim_ids)>,
remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
}
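// A usage sketch (illustrative, not part of the original header): build a coordinate
// on a small merge adaptor and read back the bottom (pre-merge) index. The lengths
// and the top index value 10 are arbitrary choices for this sketch.
CK_TILE_DEVICE inline void adaptor_coordinate_example()
{
    constexpr auto adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(number<4>{}, number<8>{}))),
        make_tuple(sequence<0, 1>{}),
        make_tuple(sequence<0>{}));
    const auto coord = make_tensor_adaptor_coordinate(adaptor, make_multi_index(10));
    // bottom index is {1, 2}, since 10 == 1 * 8 + 2
    const auto idx_bottom = coord.get_bottom_index();
    (void)idx_bottom;
}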
template <bool JudgeDoTransforms = true,
typename Adaptor,
typename AdaptorCoord,
typename TopIndex,
typename BottomIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
AdaptorCoord& coord,
const TopIndex& idx_diff_top,
BottomIndex& idx_diff_bottom)
{
constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
constexpr index_t ndim_top = Adaptor::get_num_of_top_dimension();
// constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
constexpr index_t ntransform = Adaptor::get_num_of_transform();
// static_assert(TopIndex::size() == ndim_top && BottomIndex::size() == ndim_bottom, "");
// judge whether the lower-index diff needs to be calculated for each transform
// (use index_t as the boolean type)
auto do_transforms = make_zero_multi_index<ntransform>();
if constexpr(JudgeDoTransforms)
{
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
// decide do_transforms by checking for non-zero index diff components
multi_index<ndim_top> non_zero_diff_pick_top;
static_for<0, ndim_top, 1>{}(
[&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); });
set_container_subset(
is_non_zero_diff, Adaptor::get_top_dimension_hidden_ids(), non_zero_diff_pick_top);
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
multi_index<dims_low.size()> non_zero_diff_pick_low;
// if any of the upper index diff components is non-zero, then
// 1) this transform needs to be done
// 2) all components of the lower index diff are assumed to be non-zero and need to be
// computed
const bool idx_diff_up_has_non_zero = container_reduce(
non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);
do_transforms(itran) = idx_diff_up_has_non_zero;
static_for<0, dims_low.size(), 1>{}(
[&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
});
}
else
{
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { do_transforms(itran) = 1; });
}
// this is what needs to be calculated
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
// initialize top index diff
set_container_subset(idx_diff_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_diff_top);
// this is what needs to be updated
auto& idx_hidden = coord.get_hidden_index();
// update top index
auto idx_hidden_pick_top =
get_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids());
idx_hidden_pick_top += idx_diff_top;
set_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_hidden_pick_top);
// update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
if(do_transforms[itran])
{
const auto& tran = adaptor.get_transforms().at(itran);
constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
const auto idx_up_new = get_container_subset(idx_hidden, dims_up);
auto idx_low = get_container_subset(idx_hidden, dims_low);
const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
multi_index<dims_low.size()> idx_diff_low;
tran.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up_new);
set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
set_container_subset(idx_hidden, dims_low, idx_low);
}
});
// set bottom index diff
idx_diff_bottom =
get_container_subset(idx_diff_hidden, Adaptor::get_bottom_dimension_hidden_ids());
}
template <bool JudgeDoTransforms = true, typename Adaptor, typename AdaptorCoord, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
AdaptorCoord& coord,
const TopIndex& idx_diff_top)
{
constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
multi_index<ndim_bottom> tmp;
move_tensor_adaptor_coordinate<JudgeDoTransforms>(adaptor, coord, idx_diff_top, tmp);
}
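// Illustrative usage sketch (not part of the original source; assumes a 2-d-top, 1-d-bottom
// adaptor `adaptor`, a coordinate `coord` created from it, and the make_multi_index helper):
//
//   multi_index<1> bottom_diff;
//   // full overload: the resulting bottom-index difference is returned to the caller
//   move_tensor_adaptor_coordinate(adaptor, coord, make_multi_index(0, 1), bottom_diff);
//   // convenience overload: the bottom-index difference is computed but discarded
//   move_tensor_adaptor_coordinate(adaptor, coord, make_multi_index(0, 1));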
template <typename Adaptor, typename AdaptorCoord>
CK_TILE_HOST_DEVICE constexpr bool
adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor,
const AdaptorCoord& coord)
{
bool valid = true;
constexpr index_t ntransform = Adaptor::get_num_of_transform();
const auto& idx_hidden = coord.get_hidden_index();
static_for<ntransform - 1, -1, -1>{}([&adaptor, &idx_hidden, &valid](auto itran) {
const auto tran = adaptor.get_transforms().at(itran);
// check validity only if the current transformation does not always have a valid mapping
if constexpr(!decltype(tran)::is_valid_upper_index_always_mapped_to_valid_lower_index())
{
const auto idx_up = get_container_subset(
idx_hidden, Adaptor::get_upper_dimension_hidden_idss().at(itran));
// Comment: using valid = valid && .. will result in weird control flow in ISA
valid &= tran.is_valid_upper_index_mapped_to_valid_lower_index(idx_up);
}
});
return valid;
}
template <typename Adaptor, typename AdpatorCoord>
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
const AdpatorCoord& coord)
{
// check top index
const auto& idx_top = coord.get_top_index();
bool is_top_index_valid = true;
static_for<0, Adaptor::get_num_of_dimension(), 1>{}(
[&is_top_index_valid, &idx_top, &adaptor](auto i) {
is_top_index_valid =
is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.get_length(i));
});
// check other hidden index
return is_top_index_valid &&
adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <index_t NDimHidden, typename TopDimensionHiddenIds>
struct tensor_coordinate
: public tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>
{
using Base = tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>;
// TODO make these private
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();
using HiddenIndex = multi_index<NDimHidden>;
using TopIndex = multi_index<ndim_top_>;
public:
CK_TILE_HOST_DEVICE constexpr tensor_coordinate() = default;
CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const HiddenIndex& idx_hidden)
: Base{idx_hidden}
{
}
// construct from tensor_adaptor_coordinate base class
CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const Base& adaptor_coord) : Base{adaptor_coord}
{
}
CK_TILE_HOST_DEVICE constexpr auto get_index() const { return Base::get_top_index(); }
CK_TILE_HOST_DEVICE constexpr index_t get_offset() const
{
return Base::get_bottom_index()[number<0>{}];
}
CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const
{
return Base::get_hidden_index();
}
CK_TILE_HOST_DEVICE auto& get_hidden_index() { return Base::get_hidden_index(); }
};
template <typename TensorDesc, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
const TopIndex& idx_top)
{
const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top);
return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
adaptor_coord};
}
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
CK_TILE_HOST_DEVICE constexpr void
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
{
move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step);
}
template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool
coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
}
template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
return adaptor_coordinate_is_valid(tensor_desc, coord);
}
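// Illustrative usage sketch (not part of the original source; assumes a naive row-major
// descriptor and the make_multi_index helper):
//
//   const auto desc = make_naive_tensor_descriptor(make_tuple(4, 8), make_tuple(8, 1));
//   auto coord      = make_tensor_coordinate(desc, make_multi_index(1, 2)); // offset == 1*8 + 2 == 10
//   move_tensor_coordinate(desc, coord, make_multi_index(0, 3));            // offset becomes 13
//   const bool ok   = coordinate_has_valid_offset(desc, coord);             // true, (1, 5) is in bounds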
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Transforms : tuple<transforms...>
// LowerDimensionHiddenIdss : tuple<sequence<...>, ...>
// UpperDimensionHiddenIdss : tuple<sequence<...>, ...>
// TopDimensionHiddenIds : sequence<...>
template <typename Transforms,
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename TopDimensionHiddenIds,
typename ElementSpaceSize,
typename GuaranteedVectorLengths_,
typename GuaranteedVectorStrides_>
struct tensor_descriptor : public tensor_adaptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
sequence<0>,
TopDimensionHiddenIds>
{
using Base = tensor_adaptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
sequence<0>,
TopDimensionHiddenIds>;
using ElementSpaceSizeType = ElementSpaceSize;
constexpr static index_t ntransform_ = Base::get_num_of_transform();
constexpr static index_t ndim_hidden_ = Base::get_num_of_hidden_dimension();
constexpr static index_t ndim_top_ = Base::get_num_of_top_dimension();
using GuaranteedVectorLengths = GuaranteedVectorLengths_;
using GuaranteedVectorStrides = GuaranteedVectorStrides_;
static_assert(GuaranteedVectorLengths::size() == ndim_hidden_ &&
GuaranteedVectorStrides::size() == ndim_hidden_,
"wrong! inconsistent # of hidden dimensions");
using TopIndex = multi_index<ndim_top_>;
using HiddenIndex = multi_index<ndim_hidden_>;
public:
CK_TILE_HOST_DEVICE constexpr tensor_descriptor() = default;
CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Transforms& transforms,
ElementSpaceSize element_space_size)
: Base{transforms}, element_space_size_{element_space_size}
{
static_assert(Transforms::size() == ntransform_ &&
LowerDimensionHiddenIdss::size() == ntransform_ &&
UpperDimensionHiddenIdss::size() == ntransform_,
"wrong! inconsistent # of transformations");
// TODO check dependency of dimensions is valid
}
// construct from tensor_adaptor base class
CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Base& adaptor,
ElementSpaceSize element_space_size)
: Base{adaptor}, element_space_size_{element_space_size}
{
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
{
return Base::get_num_of_top_dimension();
}
template <index_t IDim>
CK_TILE_HOST_DEVICE constexpr auto get_length(number<IDim> idim) const
{
return Base::get_top_dimension_length(idim);
}
CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
{
return Base::get_top_dimension_lengths();
}
CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const
{
return element_space_size_;
}
template <typename Idx>
CK_TILE_HOST_DEVICE constexpr index_t calculate_offset(const Idx& idx) const
{
return Base::calculate_bottom_index(idx)[number<0>{}];
}
// TODO make these private
CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const
{
return Base::get_transforms();
}
CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
{
return Base::get_lower_dimension_hidden_idss();
}
CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
{
return Base::get_upper_dimension_hidden_idss();
}
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
{
return Base::get_top_dimension_hidden_ids();
}
CK_TILE_HOST_DEVICE static constexpr bool is_static()
{
return Base::is_known_at_compile_time() &&
ck_tile::is_known_at_compile_time<ElementSpaceSize>::value;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides()
{
return Base::get_top_dimension_safe_vector_length_strides(
to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_descriptor{");
// tensor_adaptor
Base::print();
printf(", ");
// element_space_size_
printf("element_space_size_: ");
print(element_space_size_);
printf("}");
}
// TODO make these private
ElementSpaceSize element_space_size_;
};
template <typename Adaptor, typename ElementSpaceSize>
CK_TILE_HOST_DEVICE constexpr auto
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
const ElementSpaceSize& element_space_size)
{
constexpr index_t NDimHidden = Adaptor::get_num_of_hidden_dimension();
return tensor_descriptor<remove_cvref_t<decltype(adaptor.get_transforms())>,
remove_cvref_t<decltype(adaptor.get_lower_dimension_hidden_idss())>,
remove_cvref_t<decltype(adaptor.get_upper_dimension_hidden_idss())>,
remove_cvref_t<decltype(adaptor.get_top_dimension_hidden_ids())>,
remove_cvref_t<decltype(element_space_size)>,
typename uniform_sequence_gen<NDimHidden, -1>::type,
typename uniform_sequence_gen<NDimHidden, -1>::type>{
adaptor, element_space_size};
}
template <typename OldTensorDescriptor,
typename NewTransforms,
typename NewLowerDimensionOldTopIdss,
typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const NewTransforms& new_transforms,
NewLowerDimensionOldTopIdss,
NewUpperDimensionNewTopIdss)
{
const auto element_space_size = old_tensor_desc.get_element_space_size();
const auto new_tensor_adaptor = transform_tensor_adaptor(old_tensor_desc,
new_transforms,
NewLowerDimensionOldTopIdss{},
NewUpperDimensionNewTopIdss{});
constexpr index_t NDimHiddenOld = OldTensorDescriptor::get_num_of_hidden_dimension();
constexpr index_t NDimHiddenNew = decltype(new_tensor_adaptor)::get_num_of_hidden_dimension();
using NewGuaranteedVectorLengths = typename sequence_merge<
typename OldTensorDescriptor::GuaranteedVectorLengths,
typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;
using NewGuaranteedVectorStrides = typename sequence_merge<
typename OldTensorDescriptor::GuaranteedVectorStrides,
typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;
return tensor_descriptor<
remove_cvref_t<decltype(new_tensor_adaptor.get_transforms())>,
remove_cvref_t<decltype(new_tensor_adaptor.get_lower_dimension_hidden_idss())>,
remove_cvref_t<decltype(new_tensor_adaptor.get_upper_dimension_hidden_idss())>,
remove_cvref_t<decltype(new_tensor_adaptor.get_top_dimension_hidden_ids())>,
remove_cvref_t<decltype(element_space_size)>,
NewGuaranteedVectorLengths,
NewGuaranteedVectorStrides>{new_tensor_adaptor, element_space_size};
}
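// Illustrative usage sketch (not part of the original source; assumes make_merge_transform
// from coordinate_transform.hpp): merge the two top dims of a packed row-major (M, N)
// descriptor into a single (M*N) dim.
//
//   const auto desc_mn   = make_naive_tensor_descriptor_packed(make_tuple(M, N));
//   const auto desc_flat = transform_tensor_descriptor(
//       desc_mn,
//       make_tuple(make_merge_transform(make_tuple(M, N))),
//       make_tuple(sequence<0, 1>{}),
//       make_tuple(sequence<0>{}));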
namespace detail {
template <typename Lengths, typename Strides, index_t I, typename AccOld>
CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
const Strides& strides,
number<I> i,
AccOld acc_old)
{
auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];
if constexpr(i.value < Lengths::size() - 1)
{
return calculate_element_space_size_impl(lengths, strides, i + number<1>{}, acc_new);
}
else
{
return acc_new;
}
}
} // namespace detail
/*
 * These functions create naive tensor descriptors
 */
// Lengths..., Strides... could be:
// 1) index_t, which is known at run-time, or
// 2) number<>, which is known at compile-time
// element_space_size could be:
// 1) long_index_t, or
// 2) long_number<>
template <typename... Lengths,
typename... Strides,
index_t GuaranteedLastDimensionVectorLength = -1,
index_t GuaranteedLastDimensionVectorStride = -1,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor(const tuple<Lengths...>& lengths,
const tuple<Strides...>& strides,
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
constexpr index_t N = sizeof...(Lengths);
const auto transforms = make_tuple(make_embed_transform(lengths, strides));
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
constexpr auto up_dim_hidden_idss =
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
const auto element_space_size =
detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{});
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
sequence<GuaranteedLastDimensionVectorLength>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
sequence<GuaranteedLastDimensionVectorStride>>::type;
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>,
GuaranteedVectorLengths,
GuaranteedVectorStrides>{transforms, element_space_size};
}
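// Worked example (illustrative, not part of the original source; assumes the make_multi_index
// helper): a 4x8 row-major matrix.
//
//   const auto desc = make_naive_tensor_descriptor(make_tuple(number<4>{}, number<8>{}),
//                                                  make_tuple(number<8>{}, number<1>{}));
//   // element_space_size = 1 + (4 - 1) * 8 + (8 - 1) * 1 = 32
//   // desc.calculate_offset(make_multi_index(2, 5)) == 2 * 8 + 5 == 21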
// tensor descriptor with offset; the offset is not added to the element space size.
// It only records the starting offset and affects offset calculation.
template <typename... Lengths,
typename... Strides,
typename offset,
index_t GuaranteedLastDimensionVectorLength = -1,
index_t GuaranteedLastDimensionVectorStride = -1,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
const tuple<Strides...>& strides,
const offset& os,
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
const auto desc_0 = [&]() {
const auto element_space_size = detail::calculate_element_space_size_impl(
lengths, strides, number<0>{}, long_number<1>{});
const auto transforms = make_tuple(make_offset_transform(element_space_size, os));
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});
constexpr auto visible_dim_hidden_ids = sequence<1>{};
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
sequence<GuaranteedLastDimensionVectorLength>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
sequence<GuaranteedLastDimensionVectorStride>>::type;
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>,
GuaranteedVectorLengths,
GuaranteedVectorStrides>{transforms, element_space_size};
}();
constexpr index_t N = sizeof...(Lengths);
return transform_tensor_descriptor(
desc_0,
make_tuple(make_embed_transform(lengths, strides)),
make_tuple(sequence<0>{}),
make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
}
// Lengths... could be:
// 1) index_t, which is known at run-time, or
// 2) number<>, which is known at compile-time
// element_space_size could be:
// 1) long_index_t, or
// 2) long_number<>
template <typename... Lengths, index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
constexpr index_t N = sizeof...(Lengths);
const auto transforms = make_tuple(make_unmerge_transform(lengths));
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
constexpr auto up_dim_hidden_idss =
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
sequence<GuaranteedLastDimensionVectorLength>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, sequence<1>>::type;
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>,
GuaranteedVectorLengths,
GuaranteedVectorStrides>{transforms, element_space_size};
}
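// Worked example (illustrative, not part of the original source): a packed 2x3x4 descriptor,
// equivalent to strides (12, 4, 1).
//
//   const auto desc = make_naive_tensor_descriptor_packed(
//       make_tuple(number<2>{}, number<3>{}, number<4>{}));
//   // element_space_size = 2 * 3 * 4 = 24
//   // offset of index (1, 2, 3) = 1 * 12 + 2 * 4 + 3 = 23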
template <typename... Lengths,
typename... Strides,
typename Offset,
index_t GuaranteedLastDimensionVectorLength = -1,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset(
const tuple<Lengths...>& lengths,
const Offset& offset,
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
const auto desc_0 = [&]() {
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});
constexpr auto visible_dim_hidden_ids = sequence<1>{};
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
sequence<GuaranteedLastDimensionVectorLength>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type, sequence<1>>::type;
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>,
GuaranteedVectorLengths,
GuaranteedVectorStrides>{transforms, element_space_size};
}();
constexpr index_t N = sizeof...(Lengths);
return transform_tensor_descriptor(
desc_0,
make_tuple(make_unmerge_transform(lengths)),
make_tuple(sequence<0>{}),
make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
}
// Lengths... could be:
// 1) index_t, which is known at run-time, or
// 2) number<>, which is known at compile-time
// align could be:
// 1) index_t, or
// 2) number<>
template <typename... Lengths, typename Align>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_aligned(const tuple<Lengths...>& lengths, Align align)
{
constexpr auto I1 = number<1>{};
constexpr index_t N = sizeof...(Lengths);
const auto stride_n_minus_2 = integer_least_multiple(lengths[number<N - 1>{}], align);
auto strides = generate_tuple(
[&](auto i) {
if constexpr(i.value == N - 1)
{
return I1;
}
else if constexpr(i.value == N - 2)
{
return number<stride_n_minus_2>{};
}
else
{
return container_reduce(
lengths, multiplies{}, number<stride_n_minus_2>{}, i + I1, number<N - 1>{}, I1);
}
},
number<N>{});
return make_naive_tensor_descriptor(lengths, strides);
}
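// Worked example (illustrative, not part of the original source): lengths (2, 3, 10) aligned to 16.
//
//   // innermost stride = 1, next stride = integer_least_multiple(10, 16) = 16,
//   // leading stride = 3 * 16 = 48, i.e. strides = (48, 16, 1)
//   const auto desc = make_naive_tensor_descriptor_aligned(
//       make_tuple(number<2>{}, number<3>{}, number<10>{}), number<16>{});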
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename BufferView_, typename TensorDesc_>
struct tensor_view
{
using buffer_view = remove_reference_t<BufferView_>;
using DataType = typename buffer_view::type;
using TensorDesc = remove_cvref_t<TensorDesc_>;
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
CK_TILE_HOST_DEVICE constexpr tensor_view(const buffer_view& buffer_view,
const TensorDesc& desc)
: buf_{buffer_view}, desc_{desc}
{
}
CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
{
return TensorDesc::get_num_of_top_dimension();
}
CK_TILE_HOST_DEVICE constexpr const auto& get_buffer_view() const { return buf_; }
CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; }
#if 0
CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const
{
return buf_.template get<DataType>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}
CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x)
{
buf_.template set<DataType>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
#endif
// X is a vector of DataType.
// "coord" is a coordinate of DataType, not of X; "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_vectorized_elements(const TensorCoord& coord,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template get<X>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}
// X is a vector of DataType.
// "coord" is a coordinate of DataType, not of X; "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE void
get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template get_raw<X, oob_conditional_check>(
dst,
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}
template <typename X,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
const TensorCoord& coord) const
{
return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
}
// X is a vector of DataType.
// "coord" is a coordinate of DataType, not of X; "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
{
buf_.template set<X, oob_conditional_check>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
{
buf_.template set_raw<X, oob_conditional_check>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_view{");
// buf_
printf("buf_: ");
print(buf_);
printf(", ");
// desc_
printf("desc_: ");
print(desc_);
printf("}");
}
// member
buffer_view buf_;
TensorDesc desc_;
};
// placeholder type if we want to opt out of a tensor view parameter
struct null_tensor_view
{
};
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
typename DataType,
typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
const tensor_descriptor<Ts...>& desc)
{
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
typename DataType,
typename... Lengths,
typename... Strides,
index_t GuaranteedLastDimensionVectorLength = -1,
index_t GuaranteedLastDimensionVectorStride = -1,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view(DataType* p,
const tuple<Lengths...>& lengths,
const tuple<Strides...>& strides,
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
auto desc = make_naive_tensor_descriptor(lengths,
strides,
number<GuaranteedLastDimensionVectorLength>{},
number<GuaranteedLastDimensionVectorStride>{});
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
typename DataType,
typename... Lengths,
index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view_packed(DataType* p,
const tuple<Lengths...>& lengths,
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
auto desc =
make_naive_tensor_descriptor_packed(lengths, number<GuaranteedLastDimensionVectorLength>{});
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}
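// Illustrative usage sketch (not part of the original source; assumes a global-memory pointer p,
// runtime lengths M and N, and that address_space_enum::global is available):
//
//   auto view = make_naive_tensor_view<address_space_enum::global>(
//       p, make_tuple(M, N), make_tuple(N, 1)); // row-major (M, N) view over p
//   auto packed_view = make_naive_tensor_view_packed<address_space_enum::global>(
//       p, make_tuple(M, N));                   // same shape, strides derived as (N, 1)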
template <typename OldTensorView,
typename NewTransforms,
typename NewLowerDimensionOldVisibleIdss,
typename NewUpperDimensionNewVisibleIdss>
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& old_tensor_view,
const NewTransforms& new_transforms,
NewLowerDimensionOldVisibleIdss,
NewUpperDimensionNewVisibleIdss)
{
auto new_desc = transform_tensor_descriptor(old_tensor_view.desc_,
new_transforms,
NewLowerDimensionOldVisibleIdss{},
NewUpperDimensionNewVisibleIdss{});
return tensor_view<typename OldTensorView::buffer_view, remove_cvref_t<decltype(new_desc)>>{
old_tensor_view.buf_, new_desc};
}
template <typename TensorView,
typename TileLengths, // tuple<...>
typename DoPads> // sequence<bool, bool, ...>
CK_TILE_HOST_DEVICE constexpr auto
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
{
constexpr index_t num_dim = DoPads::size();
static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),
"wrong! inconsistent # of dimensions");
// transforms
const auto transforms = generate_tuple(
[&](auto idim) {
const auto old_length = tensor_view.get_tensor_descriptor().get_length(idim);
const auto tile_length = tile_lengths[idim];
const auto new_length = integer_divide_ceil(old_length, tile_length) * tile_length;
const auto pad_length = new_length - old_length;
constexpr bool DoPad = DoPads::at(idim);
const auto transform =
conditional_expr<DoPad>(make_right_pad_transform(old_length, pad_length),
make_pass_through_transform(old_length));
return transform;
},
number<num_dim>{});
// lower dimension Id
const auto lower_dimss =
generate_tuple([&](auto idim) { return sequence<idim.value>{}; }, number<num_dim>{});
// upper dimension Id
const auto upper_dimss = lower_dimss;
return transform_tensor_view(tensor_view, transforms, lower_dimss, upper_dimss);
}
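// Illustrative usage sketch (not part of the original source): pad a (100, 60) view so that both
// dims become multiples of a 64x64 tile; out-of-range accesses are then guarded by the right-pad
// transform via coordinate_has_valid_offset_assuming_top_index_is_valid().
//
//   auto a_view = make_naive_tensor_view<address_space_enum::global>(
//       p_a, make_tuple(100, 60), make_tuple(60, 1));
//   auto a_padded = pad_tensor_view(a_view,
//                                   make_tuple(number<64>{}, number<64>{}),
//                                   sequence<true, true>{});
//   // a_padded now has lengths (128, 64)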
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// distributed span
template <index_t... PartialHsLengths>
struct tile_distributed_span
{
using Impl = sequence<PartialHsLengths...>;
static constexpr auto impl_ = Impl{};
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
// distributed index
template <index_t... PartialHsIndices>
struct tile_distributed_index
{
using Impl = sequence<PartialHsIndices...>;
static constexpr auto impl_ = Impl{};
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
namespace detail {
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_span(sequence<Is...>)
{
return tile_distributed_span<Is...>{};
}
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_index(sequence<Is...>)
{
return tile_distributed_index<Is...>{};
}
} // namespace detail
template <typename PsYs2XsAdaptor_,
typename Ys2DDescriptor_,
typename StaticTileDistributionEncoding_,
typename TileDistributionDetail_> // FIXME: this is to hold ad-hoc but useful info;
// should be more elegant
struct tile_distribution
{
using PsYs2XsAdaptor = remove_cvref_t<PsYs2XsAdaptor_>;
using Ys2DDescriptor = remove_cvref_t<Ys2DDescriptor_>;
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
using DstrDetail = remove_cvref_t<TileDistributionDetail_>;
static_assert(PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(),
"wrong! should be static");
static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY;
static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR;
PsYs2XsAdaptor ps_ys_to_xs_;
Ys2DDescriptor ys_to_d_;
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_x() { return NDimX; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_y() { return NDimY; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
{
#if 0
// FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed
ps_ys_to_xs_.GetBottomDimensionLengths();
#else
return generate_tuple(
[&](auto i) {
constexpr index_t x_length =
container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1);
return number<x_length>{};
},
number<NDimX>{});
#endif
}
CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const
{
return ps_ys_to_xs_;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; }
CK_TILE_HOST_DEVICE static constexpr auto get_static_tile_distribution_encoding()
{
return DstrEncode{};
}
#if 1
// Calculate replication index [R0, R1, ...] based on partition index
// FIXME: very nasty implementation
template <typename PartitionIndex>
CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const
{
static_assert(PartitionIndex::size() == NDimP, "wrong!");
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
array<index_t, NDimR> rs_idx;
static_for<0, NDimP, 1>{}([&](auto idim_p) {
constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();
static_for<0, ndim_low, 1>{}([&](auto i) {
constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];
// 0-th rh_major is the replicate dimension
if constexpr(rh_major == 0)
{
constexpr index_t adaptor_hidden_id =
DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];
// fill in
rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
}
});
});
return rs_idx;
}
#endif
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
{
constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
constexpr auto ndims_spans_minor = DstrEncode::detail::ndims_distributed_spans_minor_;
return generate_tuple(
[&](auto i) {
constexpr auto span_impl = distributed_spans_impl[i];
constexpr index_t ndim_span_minor = ndims_spans_minor[i];
constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor);
return detail::make_tile_distributed_span(span);
},
number<NDimX>{});
}
// FIXME: it's hacky to get Y index from Distributed-Index
template <typename DistributedIndices>
CK_TILE_HOST_DEVICE static constexpr auto
get_y_indices_from_distributed_indices(DistributedIndices)
{
constexpr auto ys_idx_arr = [] {
array<index_t, NDimY> ys_idx;
static_for<0, NDimY, 1>{}([&](auto i) {
constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
constexpr index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];
constexpr auto dstr_index = DistributedIndices{}[number<span_major>{}];
ys_idx(i) = dstr_index.impl_[span_minor];
});
return ys_idx;
}();
constexpr index_t ndim_y = NDimY;
return TO_SEQUENCE(ys_idx_arr, ndim_y);
}
CK_TILE_HOST_DEVICE static constexpr bool is_static()
{
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution{");
//
printf("tile_distribution_encoding: ");
print(DstrEncode{});
printf(", ");
//
printf("ps_ys_to_xs_: ");
print(ps_ys_to_xs_);
printf(", ");
//
printf("ys_to_d_: ");
print(ys_to_d_);
//
printf("}");
}
};
namespace detail {
template <index_t NDimMax>
CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t iend)
{
array<index_t, NDimMax> arr{0};
for(index_t i = 0; i < iend - ibegin; ++i)
{
arr(i) = ibegin + i;
}
return arr;
}
// this returns a constexpr encoding of tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
{
using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor;
using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor;
using Ys2RHsMajor = typename StaticTileDistributionEncoding_::Ys2RHsMajor;
using Ys2RHsMinor = typename StaticTileDistributionEncoding_::Ys2RHsMinor;
// FIXME: increase max value if fail
constexpr index_t kMaxNumTransforms = 20;
constexpr index_t kMaxMetaDataSize = 128;
constexpr index_t kMaxNumDim = 10;
using Name = coord_transform_enum;
using MetaData = meta_data_buffer<kMaxMetaDataSize>;
using NumDim = index_t;
using Dims = array<index_t, kMaxNumDim>;
using Lengths = array<index_t, kMaxNumDim>;
// Tile Adaptor
// bottom dims [x0, x1, x2, ...]
// top dims [p0, p1, ..., y0, y1, ...]
constexpr index_t ndim_x = HsLengthss::size();
// Dim Ids: [rh_major, rh_minor] to [idim_hidden]
array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_ids;
array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_lengths;
auto trans = array<tuple<Name, MetaData, NumDim, Dims, NumDim, Dims>, kMaxNumTransforms>{};
index_t num_tran = 0;
index_t hidden_dim_cnt = ndim_x;
// this is replicate transform
{
constexpr index_t ndim_r_minor = RsLengths::size();
constexpr auto r_minor_lengths = RsLengths{};
trans(num_tran++) = {
coord_transform_enum::replicate,
MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
NumDim{0},
Dims{},
NumDim{ndim_r_minor},
make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};
for(index_t i = 0; i < ndim_r_minor; ++i)
{
rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt;
rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];
hidden_dim_cnt++;
}
};
// these are unmerge transforms for X dimensions
static_for<0, ndim_x, 1>{}([&trans,
&num_tran,
&hidden_dim_cnt,
&rh_major_minor_to_hidden_ids,
&rh_major_minor_to_hidden_lengths](auto idim_x) {
// typename HsLengthss::base{}.foo();
constexpr auto h_minor_lengths =
HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
// constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
constexpr index_t ndim_h_minor = h_minor_lengths.size();
trans(num_tran++) = {
coord_transform_enum::unmerge,
MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
NumDim{1},
Dims{idim_x},
NumDim{ndim_h_minor},
make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};
for(index_t i = 0; i < ndim_h_minor; ++i)
{
rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt;
rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];
hidden_dim_cnt++;
}
});
// transform: P dimensions
constexpr index_t ndim_p = Ps2RHssMajor::size();
Dims hidden_dim_id_ps;
static_for<0, ndim_p, 1>{}([&](auto iDimP) {
//
index_t hidden_dim_id_p = hidden_dim_cnt++;
hidden_dim_id_ps(iDimP) = hidden_dim_id_p;
constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP];
static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!");
constexpr index_t ndim_low = p2RHsMajor.size();
Dims low_dims;
Lengths low_lengths;
for(index_t i = 0; i < ndim_low; ++i)
{
index_t rh_major = p2RHsMajor[i];
index_t rh_minor = p2RHsMinor[i];
low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
}
trans(num_tran++) = {coord_transform_enum::merge,
MetaData{to_array<index_t, ndim_low>(low_lengths)},
NumDim{ndim_low},
low_dims,
NumDim{1},
Dims{hidden_dim_id_p}};
});
constexpr index_t ndim_bottom = ndim_x;
constexpr auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);
constexpr auto ys_to_rhs_major = Ys2RHsMajor{};
constexpr auto ys_to_rhs_minor = Ys2RHsMinor{};
constexpr index_t ndim_y = Ys2RHsMajor::size();
constexpr index_t ndim_top = ndim_p + ndim_y;
auto top_dim_ids = hidden_dim_id_ps;
{
for(index_t i = 0; i < ndim_y; ++i)
{
index_t rh_major = ys_to_rhs_major[i];
index_t rh_minor = ys_to_rhs_minor[i];
top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
}
}
//
const auto ps_ys_to_xs_adaptor_encoding =
make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);
// descriptor: [y0, y1, ...] to [d]
Lengths y_lengths;
index_t d_length = 1;
for(index_t i = 0; i < ndim_y; ++i)
{
index_t rh_major = ys_to_rhs_major[i];
index_t rh_minor = ys_to_rhs_minor[i];
index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
y_lengths(i) = y_length;
d_length *= y_length;
}
auto tran = make_tuple(coord_transform_enum::unmerge,
MetaData{to_array<index_t, ndim_y>(y_lengths)},
NumDim{1},
Dims{0},
NumDim{ndim_y},
make_sequential_index<kMaxNumDim>(1, ndim_y + 1));
const auto ys_to_d_adaptor_encoding = make_tuple(
make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);
return make_tuple(ps_ys_to_xs_adaptor_encoding,
ys_to_d_adaptor_encoding,
d_length,
rh_major_minor_to_hidden_ids);
}
// FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
template <typename RhMajorMinor2AdaptorHiddenIdss> // tuple<sequence<...>, ...>
struct tile_distribution_detail
{
static constexpr auto rh_major_minor_to_adaptor_hidden_idss_ =
to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{});
};
} // namespace detail
// this returns a constexpr tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
{
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
constexpr auto adaptor_impl =
detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
constexpr index_t d_length = adaptor_impl.template at<2>();
constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
constexpr auto ps_ys_to_xs_adaptor =
CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
constexpr auto ys_to_d_descriptor =
make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);
//
constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
constexpr auto rh_major_minor_to_hidden_ids =
TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
return tile_distribution<
remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
remove_cvref_t<decltype(ys_to_d_descriptor)>,
remove_cvref_t<DstrEncode>,
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
// this returns a static tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
{
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
constexpr auto adaptor_impl =
detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
constexpr index_t d_length = adaptor_impl.template at<2>();
constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
constexpr auto ps_ys_to_xs_adaptor =
CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
constexpr auto ys_to_d_adaptor =
CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
constexpr auto ys_to_d_descriptor =
make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, number<d_length>{});
//
constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
constexpr auto rh_major_minor_to_hidden_ids =
TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
return tile_distribution<
remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
remove_cvref_t<decltype(ys_to_d_descriptor)>,
remove_cvref_t<DstrEncode>,
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
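// Illustrative sketch (not part of the original source): a minimal warp-tile encoding.
// One X dim of length 32 is split as H0 = (8, 4); R0 is a replicate dim of length 4.
// The single P dim (lane id) maps to (R0, H0[0]), i.e. 4 * 8 = 32 lanes, and the single
// Y dim maps to H0[1], so each lane owns 4 elements of the X dim.
//
//   constexpr auto dstr = make_static_tile_distribution(
//       tile_distribution_encoding<sequence<4>,           // RsLengths
//                                  tuple<sequence<8, 4>>, // HsLengthss (one sequence per X dim)
//                                  tuple<sequence<0, 1>>, // Ps2RHssMajor
//                                  tuple<sequence<0, 0>>, // Ps2RHssMinor
//                                  sequence<1>,           // Ys2RHsMajor
//                                  sequence<1>>{});       // Ys2RHsMinor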
//***********************************************************************************
namespace detail {
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
// only support warp-tile and block-tile
static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!");
if constexpr(Distribution::NDimP == 1)
{
return array<index_t, 1>{get_lane_id()};
}
else if constexpr(Distribution::NDimP == 2)
{
return array<index_t, 2>{get_warp_id(), get_lane_id()};
}
}
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;
template <index_t x,
index_t... xs,
index_t m,
index_t... ms,
index_t id,
index_t... ids,
index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
sequence<m, ms...>,
sequence<id, ids...>,
SliceSize>
{
using old_scan =
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths =
typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
using dim_slices =
typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
using remaining_slice_sizes = typename sequence_merge<
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
typename old_scan::remaining_slice_sizes>::type;
// the first idx at which the sliced length is not equal to the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, number<id>, number<0>>::value;
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
static constexpr index_t split_idx = std::
conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
static constexpr auto slice_size = SliceSize;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths = sequence<slice_length>;
using dim_slices = sequence<x / slice_length>;
using remaining_slice_sizes =
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;
// the first idx at which the sliced length is not equal to the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, number<id>, number<0>>::value;
};
// clang-format off
// input: a sequence (with an optional mask) and SliceSize: the size per slice
// output: the per-dimension lengths of each slice, and the number of slices per dimension
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>; slice_index is the index (scanning right -> left)
// at which slices start to split, i.e. the first index whose sliced length
// differs from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
static_assert(Seq::size() == Mask::size());
using sliced_type =
reverse_slice_sequence_impl<Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},
number<sliced_type::split_idx>{});
}
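// Illustrative sketch (not part of the original source; values follow the table above):
//
//   constexpr auto r = reverse_slice_sequence(sequence<4, 2, 8>{}, number<8>{});
//   // r == make_tuple(sequence<1, 1, 8>{}, sequence<4, 2, 1>{}, number<1>{})
//   //      i.e. 4 * 2 * 1 = 8 slices, and splitting starts at index 1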
//
// slice the tensor along x_dim, which results in a split along y_dim, not p_dim.
// We don't support slicing across p_dim (i.e. slicing across different threads);
// also, the sliced y_dim needs to be the leading dim of the current x_dim.
// Having another (non-unit) Y dim in front of the sliced dim does not make sense.
//
// e.g
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
// |--> slice along this Y dim, the P dim is 1 in the left, so is OK
// totally 16 slices
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
// |--> slice along this P dim, will split threads, not supported
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
// |--> slice along this Y dim, but this Y dim needs to split into 2
// sub-dims
// the P dim to the left is 1, which means it does not actually cross P
//
template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
Distribution, sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)
{
// NOTE: this function needs to be called in a constexpr context;
// due to https://wg21.link/p2280r0 we have to use a non-reference type for the distribution
using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
constexpr auto src_y_info = Encoding::detail::get_sorted_y_info();
constexpr auto src_y_dims = src_y_info[number<0>{}];
constexpr auto src_y_maps = src_y_info[number<1>{}];
constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
{
auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
auto y_slice_lengths = Encoding::detail::ys_lengths_;
// This lambda modifies values captured from the enclosing scope, so C++ will not treat
// the return value as constexpr
// TODO: ugly
auto new_h_lengths = transform_tuples(
[&](auto h_len, auto id) {
constexpr auto sliced_h =
reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});
constexpr auto sliced_h_lens = sliced_h[number<0>{}];
constexpr auto sliced_h_index = sliced_h[number<2>{}];
// update y_slice_lengths
constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
"not sliced at y dim, please check");
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_slice_lengths(src_y_maps[found_y_index - i]) =
sliced_h_lens[sliced_h_index - i];
});
// TODO: add validation that the slice does not cross a p dim
// NOTE: this y_origin is for all dims, not only the current dim;
// we will later use y_picks to select the target dims
constexpr auto y_origin = [&]() {
constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});
auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
});
return y_origin_;
}();
constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
src_y_prefix_sum[id + 1],
1>::type{};
set_container_subset(
y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
return sliced_h_lens;
},
typename Encoding::HsLengthss{},
typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{});
auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
}
();
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
constexpr auto sliced_y_origins_size = sliced_y_origins_array.size();
constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}];
constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.size();
constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
return make_tuple(
make_static_tile_distribution(
tile_distribution_encoding<typename Encoding::RsLengths,
decltype(sliced_h_lengths), // only need to change the
// h_lengths type
typename Encoding::Ps2RHssMajor,
typename Encoding::Ps2RHssMinor,
typename Encoding::Ys2RHsMajor,
typename Encoding::Ys2RHsMinor>{}),
sliced_y_origins,
sliced_y_lengths);
}
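// Usage sketch (hypothetical, mirrors the len:<0, 32> example above; SomeDistribution is an
// assumed static tile distribution whose X lengths match the slice begins/ends):
//   constexpr auto sliced = slice_distribution_from_x(
//       SomeDistribution{}, sequence<0, 0>{}, sequence<0, 32>{});
//   constexpr auto sliced_dstr      = sliced[number<0>{}]; // distribution of the slice
//   constexpr auto sliced_y_origins = sliced[number<1>{}]; // per-Y slice origins
//   constexpr auto sliced_y_lengths = sliced[number<2>{}]; // per-Y slice lengths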
} // namespace detail
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename RsLengths_, // sequence<...>
typename HsLengthss_, // tuple<sequence<...>, ...>
typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
typename Ys2RHsMajor_, // sequence<...>
typename Ys2RHsMinor_> // sequence<...>
struct tile_distribution_encoding
{
using RsLengths = remove_cvref_t<RsLengths_>;
using HsLengthss = remove_cvref_t<HsLengthss_>;
using Ps2RHssMajor = remove_cvref_t<Ps2RHssMajor_>;
using Ps2RHssMinor = remove_cvref_t<Ps2RHssMinor_>;
using Ys2RHsMajor = remove_cvref_t<Ys2RHsMajor_>;
using Ys2RHsMinor = remove_cvref_t<Ys2RHsMinor_>;
static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!");
static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!");
static constexpr index_t NDimX = HsLengthss::size();
static constexpr index_t NDimP = Ps2RHssMajor::size();
static constexpr index_t NDimY = Ys2RHsMajor::size();
static constexpr index_t NDimR = RsLengths::size();
// FIXME: move into detail
static constexpr auto rs_lengths_ = RsLengths{};
static constexpr auto hs_lengthss_ = HsLengthss{};
static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
// redundant but useful info
// TODO: really bad code, should be over-hauled
struct detail
{
        // ndim_rh_major_, ndim_span_major_
static constexpr index_t ndim_rh_major_ = NDimX + 1;
static constexpr index_t ndim_span_major_ = NDimX;
// ndims_rhs_minor_[ndim_rh_major_]
static constexpr auto ndims_rhs_minor_ = generate_array(
[](auto i) {
if constexpr(i.value == 0)
{
return rs_lengths_.size();
}
else
{
return hs_lengthss_[i - number<1>{}].size();
}
},
number<ndim_rh_major_>{});
// max_ndim_rh_minor_
static constexpr index_t max_ndim_rh_minor_ =
container_reduce(ndims_rhs_minor_, maximize<index_t>{}, 0);
// rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
static constexpr auto rhs_lengthss_ =
to_array_of_array(container_concat(make_tuple(rs_lengths_), hs_lengthss_));
// ys_lengths_
static constexpr auto ys_lengths_ = [] {
array<index_t, NDimY> ys_lengths_tmp{-1};
for(index_t i = 0; i < NDimY; i++)
{
index_t rh_major = ys_to_rhs_major_[i];
index_t rh_minor = ys_to_rhs_minor_[i];
ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor];
}
return ys_lengths_tmp;
}();
        // rhs_major_minor_to_ys_[ndim_rh_major_][max_ndim_rh_minor_]
static constexpr auto rhs_major_minor_to_ys_ = [] {
array<array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};
static_for<0, NDimY, 1>{}([&](auto i) {
constexpr index_t rh_major = ys_to_rhs_major_[i];
constexpr index_t rh_minor = ys_to_rhs_minor_[i];
rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
});
return rhs_major_minor_to_ys_tmp;
}();
// ndims_span_minor_[NDimY]
static constexpr auto ndims_span_minor_ = [] {
array<index_t, NDimX> ndims_span_minor{0};
for(index_t i = 0; i < NDimY; i++)
{
const index_t span_major = ys_to_rhs_major_[i] - 1;
ndims_span_minor(span_major)++;
}
return ndims_span_minor;
}();
// max_ndim_span_minor_
static constexpr index_t max_ndim_span_minor_ =
container_reduce(ndims_span_minor_, maximize<index_t>{}, 0);
// rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
static constexpr auto rhs_major_minor_to_span_minor_ = [] {
array<array<index_t, max_ndim_rh_minor_>, ndim_rh_major_> rhs_major_minor_to_span_minor{
{-1}};
static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major];
index_t cnt_ndim_span_minor = 0;
static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];
if(idim_y >= 0)
{
rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
cnt_ndim_span_minor++;
}
});
});
return rhs_major_minor_to_span_minor;
}();
// ys_to_span_major_[NDimY]
static constexpr auto ys_to_span_major_ =
generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number<NDimY>{});
// ys_to_span_minor_[NDimY]
static constexpr auto ys_to_span_minor_ = generate_array(
[](auto i) {
return rhs_major_minor_to_span_minor_[ys_to_rhs_major_[i]][ys_to_rhs_minor_[i]];
},
number<NDimY>{});
// distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
static constexpr auto distributed_spans_lengthss_ = [] {
array<array<index_t, max_ndim_span_minor_>, ndim_span_major_>
distributed_spans_lengthss{{-1}};
static_for<0, NDimY, 1>{}([&](auto i) {
const index_t rh_major = ys_to_rhs_major_[i];
const index_t rh_minor = ys_to_rhs_minor_[i];
const index_t h_length = hs_lengthss_[number<rh_major - 1>{}][rh_minor];
const index_t span_major = rh_major - 1;
const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];
distributed_spans_lengthss(span_major)(span_minor) = h_length;
});
return distributed_spans_lengthss;
}();
// ndims_distributed_spans_minor_[ndim_span_major_]
static constexpr auto ndims_distributed_spans_minor_ = [] {
array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};
static_for<0, NDimY, 1>{}([&](auto i) {
const index_t span_major = ys_to_rhs_major_[i] - 1;
ndims_distributed_spans_minor(span_major)++;
});
return ndims_distributed_spans_minor;
}();
// does_p_own_r_[NDimP][NDimR]
static constexpr auto does_p_own_r_ = [] {
if constexpr(NDimR > 0)
{
array<array<bool, NDimR>, NDimP> does_p_own_r{{false}};
static_for<0, NDimP, 1>{}([&](auto idim_p) {
constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
if constexpr(rh_major == 0)
{
does_p_own_r(idim_p)(rh_minor) = true;
}
});
});
return does_p_own_r;
}
else
{
return array<array<bool, NDimR>, NDimP>{};
}
}();
// ps_over_rs_derivative_[NDimP][NDimR]
static constexpr auto ps_over_rs_derivative_ = [] {
if constexpr(NDimR > 0)
{
array<array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};
static_for<0, NDimP, 1>{}([&](auto idim_p) {
constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
index_t p_over_rh_derivative = 1;
static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];
if constexpr(rh_major == 0)
{
ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
}
p_over_rh_derivative *= rh_length;
});
});
return ps_over_rs_derivative;
}
else
{
return array<array<index_t, NDimR>, NDimP>{};
}
}();
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
{
// <len_d0, len_d1, ...>
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
[&](auto i) {
constexpr index_t size = HsLengthss{}[i].size();
return number<size>{};
},
number<NDimX>{});
// <0, len_d0, len_d0+len_d1, ...>
// e.g. seq<3, 5> --> seq<0, 3, 8>
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);
return h_dim_prefix_sum;
}
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
{
constexpr auto all_ys_2_rhss = transform_sequences(
[](auto major, auto minor) constexpr {
// <0, 0, len_d0, len_d0+len_d1, ...>
constexpr auto x_dim_prefix_sum = merge_sequences(
sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());
return x_dim_prefix_sum.at(major) + minor;
},
Ys2RHsMajor{},
Ys2RHsMinor{});
return all_ys_2_rhss;
}
// return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
template <typename IdxSeq, typename PrefixSumSeq>
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
{
using sorted_idx = sequence_unique_sort<IdxSeq, less<index_t>, equal<index_t>>;
constexpr auto sorted_dims = typename sorted_idx::type{};
constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};
constexpr auto sorted_histogram =
histogram_sorted_sequence(sorted_dims, PrefixSumSeq{});
constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);
return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
}
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
{
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding::detail{");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("ndim_span_major_: ");
print(ndim_span_major_);
printf(", ");
//
printf("ndims_rhs_minor_: ");
print(ndims_rhs_minor_);
printf(", ");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("max_ndim_rh_minor_: ");
print(max_ndim_rh_minor_);
printf(", ");
//
printf("rhs_lengthss_: ");
print(rhs_lengthss_);
printf(", ");
//
printf("ys_lengths_: ");
print(ys_lengths_);
printf(", ");
//
printf("rhs_major_minor_to_ys_: ");
print(rhs_major_minor_to_ys_);
printf(", ");
//
printf("ndims_span_minor_: ");
print(ndims_span_minor_);
printf(", ");
//
printf("max_ndim_span_minor_: ");
print(max_ndim_span_minor_);
printf(", ");
//
printf("ys_to_span_major_: ");
print(ys_to_span_major_);
printf(", ");
//
printf("ys_to_span_minor_: ");
print(ys_to_span_minor_);
printf(", ");
//
printf("distributed_spans_lengthss_: ");
print(distributed_spans_lengthss_);
printf(", ");
//
printf("ndims_distributed_spans_minor_: ");
print(ndims_distributed_spans_minor_);
printf(", ");
//
printf("ps_over_rs_derivative_: ");
print(ps_over_rs_derivative_);
//
printf("}");
}
};
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding{");
//
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
//
printf("rs_lengths_: ");
print(rs_lengths_);
printf(", ");
//
printf("hs_lengthss_: ");
print(hs_lengthss_);
printf(", ");
//
printf("ps_to_rhss_major_: ");
print(ps_to_rhss_major_);
printf(", ");
//
printf("ps_to_rhss_minor_: ");
print(ps_to_rhss_minor_);
printf(", ");
//
printf("ys_to_rhs_major_: ");
print(ys_to_rhs_major_);
printf(", ");
//
printf("ys_to_rhs_minor_: ");
print(ys_to_rhs_minor_);
printf(", ");
//
printf("detail: ");
print(detail{});
//
printf("}");
}
};
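// A minimal sketch (hypothetical lengths, not taken from any kernel) of how the template
// parameters fit together: two X dims, each split hierarchically, with P dims picking the
// warp/lane components and Y dims picking the per-thread components:
//   using ExampleEncoding = tile_distribution_encoding<
//       sequence<1>,                           // RsLengths: one replication dim of length 1
//       tuple<sequence<2, 4>, sequence<8, 4>>, // HsLengthss: X0 = 2*4, X1 = 8*4
//       tuple<sequence<1>, sequence<2>>,       // Ps2RHssMajor: P0 -> H0, P1 -> H1 (0 would mean R)
//       tuple<sequence<0>, sequence<0>>,       // Ps2RHssMinor: both P dims take the first H component
//       sequence<1, 2>,                        // Ys2RHsMajor: Y0 -> H0, Y1 -> H1
//       sequence<1, 1>>;                       // Ys2RHsMinor: both Y dims take the second H component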
namespace detail {
template <typename OuterDstr, typename InnerDstr>
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
{
static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");
constexpr index_t NDimHMajor = OuterDstr::NDimX;
using RsLengths =
sequence_merge_t<typename OuterDstr::RsLengths, typename InnerDstr::RsLengths>;
constexpr auto hs_lengthss = generate_tuple(
[&](auto i) {
return merge_sequences(typename OuterDstr::HsLengthss{}[i],
typename InnerDstr::HsLengthss{}[i]);
},
number<NDimHMajor>{});
//
constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;
// R dimension
rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();
// Hs dimensions
static_for<0, NDimHMajor, 1>{}([&](auto i) {
rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size();
});
return rhs_major_2_ndim_outer_rhs_minor_;
}();
// Ps2RHssMinor
constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
[&](auto p) {
constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];
constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size();
constexpr auto updated_inner_p_2_rhss_minor = [&]() {
array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;
for(index_t i = 0; i < ndim_tmp; i++)
{
index_t rh_major = inner_p_2_rhss_major[i];
index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
}
return updated_inner_p_2_rhss_minor_;
}();
return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
},
number<InnerDstr::NDimP>{});
// Ys2RHsMinor
constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};
constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size();
constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;
for(index_t i = 0; i < ndim_tmp; i++)
{
index_t rh_major = inner_ys_2_rhs_major[i];
index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
}
return updated_inner_ys_2_rhs_minor__;
}();
return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
}();
//
constexpr auto ps_2_rhss_major =
container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});
constexpr auto ps_2_rhss_minor =
container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
//
constexpr auto ys_2_rhs_major =
merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});
constexpr auto ys_2_rhs_minor =
merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
return tile_distribution_encoding<RsLengths,
remove_cvref_t<decltype(hs_lengthss)>,
remove_cvref_t<decltype(ps_2_rhss_major)>,
remove_cvref_t<decltype(ps_2_rhss_minor)>,
remove_cvref_t<decltype(ys_2_rhs_major)>,
remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
}
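// Usage sketch (hypothetical encodings): per X dim the resulting H lengths are the outer
// lengths followed by the inner lengths, and the inner P/Y minor indices are shifted past
// the outer ones:
//   constexpr auto embedded =
//       make_embed_tile_distribution_encoding(OuterEncoding{}, InnerEncoding{});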
template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding_impl(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
constexpr auto I1 = number<1>{};
// FIXME: increase if fail
constexpr index_t max_ndim_r_out = 20;
constexpr index_t max_ndim_y_out = 20;
//
constexpr index_t ndim_p = InDstr::NDimP;
constexpr index_t ndim_x_in = InDstr::NDimX;
constexpr index_t ndim_y_in = InDstr::NDimY;
constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1;
constexpr index_t ndim_x_out = ndim_x_in - sizeof...(InReduceDimXs);
constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;
// ndims_ps_low
constexpr auto ndims_ps_low = generate_array(
[&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number<ndim_p>{});
// is_rh_major_in_for_reduce
array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};
for(index_t i = 0; i < reduce_dim_xs_in.size(); i++)
{
index_t rh_major = reduce_dim_xs_in[i] + 1;
is_rh_major_in_for_reduce(rh_major) = true;
}
// is_y_in_for_reduce
array<bool, ndim_y_in> is_y_in_for_reduce{false};
for(index_t i = 0; i < ndim_y_in; i++)
{
index_t rh_major = InDstr::ys_to_rhs_major_[i];
if(is_rh_major_in_for_reduce[rh_major])
{
is_y_in_for_reduce(i) = true;
}
}
// is_rh_minor_in_for_y_reduce
array<array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}};
static_for<0, ndim_y_in, 1>{}([&](auto i) {
index_t rh_major = InDstr::ys_to_rhs_major_[i];
index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
if(is_y_in_for_reduce[i])
{
is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
}
});
// in2out_rh_major
array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
index_t cnt_ndim_rh_major_out = 0;
for(index_t i = 0; i < ndim_rh_major_in; i++)
{
if(is_rh_major_in_for_reduce[i])
{
in2out_rh_major(i) = 0;
}
else
{
in2out_rh_major(i) = cnt_ndim_rh_major_out;
cnt_ndim_rh_major_out++;
}
}
// rs_lengths_out, in2out_rh_minor
array<index_t, max_ndim_r_out> rs_lengths_out{-1};
array<array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};
// loop over input R dim
for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
{
// rs_lengths_out
rs_lengths_out(i) = InDstr::rs_lengths_[i];
// in2out_rh_minor
in2out_rh_minor(0)(i) = i;
}
// loop over input H Dim
index_t cnt_ndim_r_out = InDstr::rs_lengths_.size();
static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
constexpr auto h_major_in = rh_major_in - I1;
constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();
if(is_rh_major_in_for_reduce[rh_major_in])
{
for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
{
if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
{
// rs_lengths_out
rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
// in2out_rh_minor
in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
cnt_ndim_r_out++;
}
}
}
else
{
for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
{
// in2out_rh_minor
in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
}
}
});
// ndim_r_out
const index_t ndim_r_out = cnt_ndim_r_out;
// ndims_hs_minor_out, hs_lengthss_out
array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
array<array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};
index_t cnt_ndim_x_out = 0;
static_for<0, ndim_x_in, 1>{}([&](auto i) {
if(not is_rh_major_in_for_reduce[i + I1])
{
// ndims_hs_minor_out
ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();
// hs_lengthss_out
static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
[&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
cnt_ndim_x_out++;
}
});
// ps_to_rhss_major_out, ps_to_rhss_minor_out
array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};
static_for<0, ndim_p, 1>{}([&](auto idim_p) {
static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) {
index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
});
});
// ys_to_rhs_major_out, ys_to_rhs_minor_out
array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};
index_t cnt_ndim_y_out = 0;
static_for<0, ndim_y_in, 1>{}([&](auto i) {
if(not is_y_in_for_reduce[i])
{
index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
cnt_ndim_y_out++;
}
});
// ndim_y_out
const index_t ndim_y_out = cnt_ndim_y_out;
//
return make_tuple(ndim_x_out,
ndim_p,
ndim_y_out,
ndim_r_out,
ndims_hs_minor_out,
ndims_ps_low,
rs_lengths_out,
hs_lengthss_out,
ps_to_rhss_major_out,
ps_to_rhss_minor_out,
ys_to_rhs_major_out,
ys_to_rhs_minor_out);
}
template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);
constexpr index_t ndim_x = impl.template at<0>();
constexpr index_t ndim_p = impl.template at<1>();
constexpr index_t ndim_y = impl.template at<2>();
constexpr index_t ndim_r = impl.template at<3>();
constexpr auto ndims_hs_minor = impl.template at<4>();
constexpr auto ndims_ps_low = impl.template at<5>();
constexpr auto rs_lengths_impl = impl.template at<6>();
constexpr auto hs_lengthss_impl = impl.template at<7>();
constexpr auto ps_to_rhss_major_impl = impl.template at<8>();
constexpr auto ps_to_rhss_minor_impl = impl.template at<9>();
constexpr auto ys_to_rhs_major_impl = impl.template at<10>();
constexpr auto ys_to_rhs_minor_impl = impl.template at<11>();
constexpr auto rs_lengths = TO_SEQUENCE(rs_lengths_impl, ndim_r);
constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
constexpr auto ps_to_rhss_major =
TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
constexpr auto ps_to_rhss_minor =
TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);
return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
remove_cvref_t<decltype(hs_lengthss)>,
remove_cvref_t<decltype(ps_to_rhss_major)>,
remove_cvref_t<decltype(ps_to_rhss_minor)>,
remove_cvref_t<decltype(ys_to_rhs_major)>,
remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
}
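// Usage sketch (hypothetical): drop X dim 1 of an input encoding; the non-Y components of the
// reduced dim become new R dims and its Y components are removed:
//   constexpr auto reduced =
//       make_reduce_tile_distribution_encoding(SomeEncoding{}, sequence<1>{});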
} // namespace detail
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// TODO: support tensors with different distribution
template <typename InOutElementFunc,
typename... InOutDstrTensors,
typename = std::enable_if_t<std::conjunction_v<
std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
InOutDstrTensors&... inout_dstr_tensors)
{
// TODO: make sure all distributed tensors have same lengths and distribution
// static_assert(xxx);
constexpr index_t thread_buffer_size =
__type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();
static_for<0, thread_buffer_size, 1>{}(
[&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
}
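// Usage sketch (acc is an assumed static_distributed_tensor, not part of this header):
//   tile_elementwise_inout([](auto& x) { x = x * 2; }, acc);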
template <typename InElementFunc,
typename... InTensor,
typename = std::enable_if_t<
std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
const InTensor&... in_dstr_tensors)
{
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
// TODO: make sure all distributed tensors have same lengths and distribution
// static_assert(xxx);
constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();
constexpr index_t thread_buffer_size =
__type_pack_element<0, InTensor...>::get_thread_buffer_size();
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
out_dstr_tensor.get_thread_buffer()(i) =
in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
});
return out_dstr_tensor;
}
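// Usage sketch (a_tile/b_tile are assumed static_distributed_tensors with the same
// distribution): returns a new distributed tensor holding the elementwise result:
//   auto c_tile = tile_elementwise_in(
//       [](const auto& a, const auto& b) { return a + b; }, a_tile, b_tile);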
template <typename DstrTensors, typename T>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
{
tile_elementwise_inout(
[&value](auto& x) {
x = type_convert<typename DstrTensors::DataType, remove_cvref_t<T>>(value);
},
dstr_tensor);
}
template <typename T>
CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
{
}
// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
// sub-dword tensor...
template <typename DstrTensors, index_t v>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
{
constexpr index_t tensor_bytes =
DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
if constexpr(v == 0 && tensor_bytes % 4 == 0)
{
using dvec_t = array<index_t, tensor_bytes / 4>;
auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
for(auto i = 0; i < tensor.size(); i++)
tensor.get(i) = v;
}
else
{
tile_elementwise_inout(
[](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
dstr_tensor);
}
}
template <index_t v>
CK_TILE_DEVICE void set_tile(null_tensor&, number<v>)
{
}
template <typename DstrTensors>
CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
{
set_tile(dstr_tensor, 0);
}
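// Usage sketch (acc is an assumed static_distributed_tensor):
//   set_tile(acc, 1.0f);        // broadcast a runtime value
//   set_tile(acc, number<0>{}); // compile-time zero, may take the per-dword fast path
//   clear_tile(acc);            // equivalent to set_tile(acc, 0)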
namespace impl {
// TODO: this is ugly
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // This API is designed to use the _pk_ series of conversion builtins
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
static_assert(thread_buffer_size % 4 == 0);
constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
    // __builtin_amdgcn_cvt_pk_fp8_f32() requires the old value and would generate a
    // v_mov_b32 vxxx [old] before the cvt, which results in unwanted ISA,
    // so we purposely prepare an uninitialized variable and turn off the warning
int dummy_old;
static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) {
uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32(
in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}],
in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}],
dummy_old,
false); // false -> WORD0
uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}],
in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}],
dummy_old,
false); // false -> WORD0
constexpr int32_t m0 = 0x05040100;
using vec_t = array<OutDataType, 4>;
vec_t d = bit_cast<vec_t>(__builtin_amdgcn_perm(y, x, m0));
out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
});
#pragma clang diagnostic pop
return out_dstr_tensor;
#else
// fallback
return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
in_dstr_tensors);
#endif
}
#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assumes either the src or dst (or both) data type is under 1 dword;
// we pack sub-dword values into 1 dword to avoid the compiler's default sub-dword behavior (which is buggy)
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
{
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
using i_type = remove_cvref_t<typename InTensor::DataType>;
using o_type = remove_cvref_t<OutDataType>;
constexpr index_t i_elem_bytes = sizeof(i_type);
constexpr index_t o_elem_bytes = sizeof(o_type);
static_assert(i_elem_bytes < 4 || o_elem_bytes < 4);
constexpr index_t bulk_size =
(i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes);
static_assert(bulk_size != 0);
using o_bulk_type =
std::conditional_t<i_elem_bytes >= o_elem_bytes, float, array<o_type, bulk_size>>;
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
constexpr index_t iters = thread_buffer_size / bulk_size;
constexpr index_t rems = thread_buffer_size % bulk_size;
// cast the sequence per-bulk
static_for<0, iters, 1>{}([&](auto i) {
union bulk_wrapper
{
o_bulk_type bulk{};
o_type data[bulk_size];
} o_bulk;
        // TODO: should use the function below, but it somehow results in spills (same as a C for-loop)
static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) {
o_bulk.data[ib.value] = static_cast<o_type>(
in_dstr_tensors.get_thread_buffer()
.template get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
});
// TODO: fixme, should use above!
// static_assert(sizeof(i_type) / sizeof(o_type) == 2);
// o_bulk.data[0] = static_cast<o_type>(
// in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
// o_bulk.data[1] = static_cast<o_type>(
// in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);
out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
});
static_for<0, rems, 1>{}([&](auto r) {
// TODO: introducing local scratch pad?
auto idx = number<iters * bulk_size + r>{};
out_dstr_tensor.get_thread_buffer().at(idx) =
static_cast<o_type>(in_dstr_tensors.get_thread_buffer().at(idx));
});
return out_dstr_tensor;
}
#endif
} // namespace impl
template <typename DstType, typename SrcTensor>
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
{
if constexpr((std::is_same_v<DstType, fp8_t> ||
std::is_same_v<DstType, bf8_t>)&&std::is_same_v<typename SrcTensor::DataType,
float> &&
(SrcTensor::get_thread_buffer_size() % 4 == 0))
{
return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
}
#if CK_TILE_USE_SUBDWORD_TILE_CAST
else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
{
return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
}
#endif
else
return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>, src_tensor);
}
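// Usage sketch (acc_fp32 is an assumed fp32 static_distributed_tensor):
//   auto c_fp16 = cast_tile<fp16_t>(acc_fp32); // sub-dword path when enabled
//   auto c_fp8  = cast_tile<fp8_t>(acc_fp32);  // packed fp8 path on gfx94x when size % 4 == 0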
// no-op function for null_tensor arguments
template <typename InOutElementFunc,
typename... MaybeNullTensor,
typename = std::enable_if_t<
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
{
}
// no-op function for null_tensor arguments
template <typename InElementFunc,
typename... MaybeNullTensor,
typename = std::enable_if_t<
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
{
return null_tensor{};
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord>
struct tile_window_with_static_distribution
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
// TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileDstr::is_static(), "wrong!");
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
"wrong! inconsistent # of diemsnions");
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
using WindowAdaptorCoord =
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
struct load_store_traits
{
private:
static constexpr auto get_vector_dim_y_scalar_per_vector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
tile_window_with_static_distribution::
get_window_adaptor_ys_safe_vector_length_strides();
index_t VectorDimY_ = 0;
index_t ScalarPerVector_ = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
{
ScalarPerVector_ = ys_vector_lengths[i];
VectorDimY_ = i;
}
}
return make_tuple(VectorDimY_, ScalarPerVector_);
}
public:
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
// using vector_t = typename vector_type_t::type;
using vector_t = thread_buffer<DataType, ScalarPerVector>;
private:
static constexpr auto scalars_per_access_ = [] {
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr auto NDimY_ = NDimY;
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
}();
static constexpr auto get_space_filling_curve()
{
constexpr auto tile_dstr = TileDstr{};
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
// FIXME: need logic to judge dim access order
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
return space_filling_curve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access_)>{};
}
public:
using SFC_Ys = decltype(get_space_filling_curve());
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
};
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
CK_TILE_DEVICE constexpr tile_window_with_static_distribution() = default;
CK_TILE_DEVICE constexpr tile_window_with_static_distribution(
const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
pre_computed_coords_{}
{
#if 0 // debug
        // TODO: this uses more registers for FA, but fewer registers for GEMM;
        // needs investigation
        // only supports warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
        // TODO: this uses fewer registers for FA, but more registers for GEMM;
        // needs investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_distribution),
array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
{
return TileDstr::is_static();
}
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const AdaptorTopIndex& idx_diff_adaptor_top) const
{
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
window_adaptor_thread_coord,
idx_diff_adaptor_top,
idx_diff_adaptor_bottom);
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
bottom_tensor_thread_coord,
idx_diff_adaptor_bottom);
}
// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
// bottom tensor top dimension vector lengths and strides
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
// window vector lengths/strides
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
// window adaptor [p0, p1, ..., y0, y1, ...]
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
-1};
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
-1};
constexpr auto window_adaptor_bottom_dims =
WindowAdaptor::get_bottom_dimension_hidden_ids();
set_container_subset(window_adaptor_vector_lengths,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_lengths);
set_container_subset(window_adaptor_vector_strides,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_strides);
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
window_adaptor_vector_lengths, window_adaptor_vector_strides);
// [y0, y1, ...]
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
NDimWindowAdaptorTop,
1>::type{};
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}
CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }
template <bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
return dst_tensor;
}
template <typename DstTile, bool oob_conditional_check = true>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
static constexpr index_t YElementSize =
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
static_assert(YElementSize % Traits::ScalarPerVector == 0);
using vectorized_tbuf = array<vector_t, YElementSize / Traits::ScalarPerVector>;
// StaticBuffer<address_space_enum::vgpr,
// vector_t,
// YElementSize / Traits::ScalarPerVector,
// true>;
constexpr auto tile_dstr = TileDstr{};
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0);
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord,
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
// TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
// using LdsTensorView = typename LdsTileWindow::BottomTensorView;
using LdsDataType = typename LdsTileWindow::DataType;
// using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(m0_init_value); // This should be wave independent
using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
m0_inc_with_memory(size_per_issue);
}
});
});
}
template <bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor
// vector_type_t vec;
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
CK_TILE_DEVICE void
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
static constexpr bool oob_conditional_check = true;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, vec_value);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
    // move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset]
// also move window-origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
{
window_origin_ += step;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
pre_computed_coords_(iCoord)(I1),
step);
});
}
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
//
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
// this contains:
// per-thread coordinate for window adaptor
// per-thread coordinate for bottom tensor
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
};
// TODO: use strategy
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord = 1>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
number<NumCoord> = {})
{
return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
NumCoord>{
tensor_view, window_lengths, origin, tile_distribution};
}
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord>
CK_TILE_DEVICE void move_tile_window(
tile_window_with_static_distribution<TensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>& window,
const typename tile_window_with_static_distribution<TensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>::BottomTensorIndex& step)
{
window.move(step);
}
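// Usage sketch (assumed names: a_view is a tensor_view over global memory, a_dstr a static
// tile distribution matching a 128x32 window; i_m is a runtime block offset):
//   auto a_window = make_tile_window(a_view,
//                                    make_tuple(number<128>{}, number<32>{}),
//                                    {i_m, 0},
//                                    a_dstr);
//   auto a_tile = a_window.load();       // per-thread distributed tensor
//   move_tile_window(a_window, {0, 32}); // slide the window along the second dim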
template <typename BottomTensorView_, typename WindowLengths_>
struct tile_window_with_static_lengths
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = typename BottomTensorView::DataType;
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default;
CK_TILE_DEVICE constexpr tile_window_with_static_lengths(
const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin}
{
}
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
// move window-origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
//
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
};
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin)
{
static_assert(ck_tile::is_known_at_compile_time<WindowLengths_>::value,
"wrong! lengths should be static");
return tile_window_with_static_lengths<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>>{
tensor_view, window_lengths, origin};
}
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window(
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
const typename tile_window_with_static_lengths<TensorView_, WindowLengths_>::BottomTensorIndex&
step)
{
window.move(step);
}
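// Usage sketch (assumed names): without a distribution the window only carries lengths and
// origin; a tile distribution is typically attached later (outside this header) before any
// load/store:
//   auto b_window = make_tile_window(b_view,
//                                    make_tuple(number<128>{}, number<32>{}),
//                                    {i_n, 0});
//   move_tile_window(b_window, {0, 32});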
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x)
{
static_assert(__has_builtin(__builtin_bit_cast), "");
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
return __builtin_bit_cast(Y, x);
}
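// Example (illustrative): reinterpret the bits of a float as uint32_t without
// undefined behavior,
//   uint32_t u = bit_cast<uint32_t>(1.0f); // u == 0x3f800000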
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include <stdint.h>
#include <utility>
namespace ck_tile {
namespace detail {
struct swallow
{
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
{
}
};
template <class>
struct static_for_impl;
template <index_t... Is>
struct static_for_impl<sequence<Is...>>
{
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
swallow{(f(number<Is>{}), 0)...};
}
};
} // namespace detail
// F signature: F(number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
CK_TILE_HOST_DEVICE constexpr static_for()
{
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
"wrong! should satisfy (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
"NBegin >= NEnd)");
}
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
f);
}
};
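// Usage sketch (illustrative; buf and acc are assumed): a fully unrolled
// compile-time loop over i = 0, 1, 2, 3, where each i is a number<I> and can be
// used in constant expressions,
//   static_for<0, 4, 1>{}([&](auto i) { acc += buf[i]; });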
struct identity
{
template <typename T>
CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
{
return std::forward<T>(arg);
}
};
namespace detail {
// RemainLengths: sequence<...>
// Orders: sequence<...>
template <class RemainLengths, class Orders>
struct static_ford_impl
{
CK_TILE_HOST_DEVICE constexpr static_ford_impl()
{
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
}
// F signature: F(sequence<...>)
// CurrentOrderedId: sequence<...>
template <class F, class CurrentOrderedId>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
{
static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
f, CurrentOrderedId::push_back(I));
});
}
};
template <class Orders>
struct static_ford_impl<sequence<>, Orders>
{
// F signature: F(sequence<...>)
// OrderedId: sequence<...>
template <class F, class OrderedId>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
{
// retrieve unordered Id
f(OrderedId::reorder_old_to_new(Orders{}));
}
};
} // namespace detail
// Lengths is sequence<...>; it gives the length of each dimension of the
// N-dimensional loop.
// Orders is sequence<...>; it gives the order in which static_ford loops over
// the dimensions.
template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_ford
{
CK_TILE_HOST_DEVICE constexpr static_ford()
{
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
}
// F signature: F(sequence<...> multi_id)
// multi_id is the unordered multi-index
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
}
};
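// Usage sketch (illustrative): visit every index of a 2x3 iteration space at
// compile time; the functor receives the current unordered multi-index as a
// sequence<...>,
//   static_ford<sequence<2, 3>>{}([&](auto multi_id) {
//       // multi_id is sequence<i, j> for the current iteration
//   });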
namespace detail {
template <typename Indices>
struct unpack_impl;
template <index_t... Is>
struct unpack_impl<sequence<Is...>>
{
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
{
#if 0
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...);
#else
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
#endif
}
};
template <typename Seq0, typename Seq1>
struct unpack2_impl;
// TODO: remove this, after properly implementing unpack that takes any number of containers
template <index_t... Is, index_t... Js>
struct unpack2_impl<sequence<Is...>, sequence<Js...>>
{
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
{
#if 0
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...,
std::forward<Y>(y).at(number<Js>{})...);
#else
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
std::forward<Y>(y).template at<Js>()...);
#endif
}
};
} // namespace detail
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
{
using X_ = remove_reference_t<X>;
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x));
}
// TODO: properly implement unpack that takes any number of containers
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
{
using X_ = remove_reference_t<X>;
using Y_ = remove_reference_t<Y>;
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type,
typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
}
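// Usage sketch (illustrative): call a function with the elements of a tuple as
// separate arguments,
//   auto t = make_tuple(number<1>{}, number<2>{});
//   auto s = unpack([](auto a, auto b) { return a + b; }, t); // s == 3
// unpack2 does the same for two containers, forwarding all of x's elements
// followed by all of y's.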
// z = predicate ? x : y
template <bool predicate, typename X, typename Y>
constexpr auto conditional_expr(X&& x, Y&& y)
{
if constexpr(predicate)
{
return std::forward<X>(x);
}
else
{
return std::forward<Y>(y);
}
}
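// Example (illustrative; kUseFirst, x and y are assumed to come from the caller):
// pick one of two expressions based on a compile-time predicate; unlike the
// ternary operator, the two arguments may have different types,
//   auto v = conditional_expr<kUseFirst>(x, y); // v is x if kUseFirst, else y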
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// https://en.cppreference.com/w/cpp/utility/tuple/ignore
namespace ck_tile {
namespace detail {
struct ignore_t
{
template <typename T>
constexpr void operator=(T&&) const noexcept
{
}
};
} // namespace detail
inline constexpr detail::ignore_t ignore;
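// Example (illustrative; compute_something is a hypothetical callee): discard a
// value without naming it, mirroring std::ignore,
//   ignore = compute_something();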
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>
namespace ck_tile {
// magic number division
// Caution:
// 1. For uint32_t as dividend: the magic number division implementation used here produces
// the correct result only if the dividend is within the 31-bit value range.
// 2. For int32_t as dividend: magic number division for an int32_t dividend has not been
// implemented; the int32_t dividend is bit-wise reinterpreted as uint32_t and the uint32_t
// magic number division implementation is used instead. Therefore, the dividend value needs
// to be non-negative.
// TODO:
// 1. Implement magic number division for int32_t
// 2. Implement magic number division for uint32_t with the full 32-bit value range
struct magic_division32_bit_range
{
// uint32_t
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
{
// WARNING: magic division is only valid for division inside this range.
// assert(divisor >= 1 && divisor <= INT32_MAX)
uint32_t shift_u32 = 0;
while((1U << shift_u32) < divisor)
{
shift_u32++;
};
uint64_t tmp_u64 = ((1UL << shift_u32) - divisor) << 32;
uint32_t multiplier_u32 = tmp_u64 / divisor + 1;
return make_tuple(multiplier_u32, shift_u32);
}
template <auto Divisor, typename = std::enable_if_t<(0 < Divisor)>>
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
{
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
constexpr uint32_t multiplier = tmp[number<0>{}];
constexpr uint32_t shift = tmp[number<1>{}];
return make_tuple(constant<multiplier>{}, constant<shift>{});
}
// magic division for uint32_t
CK_TILE_DEVICE static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
}
CK_TILE_HOST static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
return (tmp + dividend) >> shift;
}
// magic division for int32_t
// HACK: treat dividend_i32 as if it were uint32_t; dividend_i32 needs to be
// non-negative for the result to be correct
// TODO: figure out how to do magic number division with int32_t as the dividend
CK_TILE_DEVICE static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
CK_TILE_HOST static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
return (tmp + dividend_u32) >> shift;
}
};
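// Worked example (illustrative): for divisor 3, shift is the smallest power-of-two
// exponent with (1 << shift) >= 3, i.e. 2, and
//   multiplier = (((1 << 2) - 3) << 32) / 3 + 1 = 0x55555556,
// so dividing 10 replays as
//   tmp = __umulhi(10, 0x55555556) = 3;  (3 + 10) >> 2 = 3   // == 10 / 3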
// magic number division
// This version only works for divisor and dividend within [0, 1 << 16]
struct magic_division16_bit_range
{
// uint32_t
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
{
// WARNING: magic division is only valid for division inside this range.
// assert(divisor >= 1 && divisor <= (1U << 16));
uint32_t shift_u32 = 0;
while((1U << shift_u32) < divisor)
{
shift_u32++;
};
uint32_t one = 1;
uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1;
return make_tuple(multiplier_u32, shift_u32);
}
// constant<Divisor>
template <auto Divisor>
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
{
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
constexpr uint32_t multiplier = tmp[number<0>{}];
constexpr uint32_t shift = tmp[number<1>{}];
return make_tuple(constant<multiplier>{}, constant<shift>{});
}
// magic division for uint32_t
CK_TILE_DEVICE static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = (dividend * multiplier) >> 16;
return (tmp + dividend) >> shift;
}
CK_TILE_HOST static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = (dividend * multiplier) >> 16;
return (tmp + dividend) >> shift;
}
// magic division for int32_t
// HACK: treat dividend_i32 as if it were uint32_t; dividend_i32 needs to be
// non-negative for the result to be correct
// TODO: figure out how to do magic number division with int32_t as the dividend
CK_TILE_DEVICE static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (dividend_u32 * multiplier) >> 16;
return (tmp + dividend_u32) >> shift;
}
CK_TILE_HOST static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (dividend_u32 * multiplier) >> 16;
return (tmp + dividend_u32) >> shift;
}
};
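// Worked example (illustrative): for divisor 3 within the 16-bit range, shift = 2
// and
//   multiplier = ((1 << 16) * ((1 << 2) - 3)) / 3 + 1 = 21846,
// so dividing 10 replays as
//   tmp = (10 * 21846) >> 16 = 3;  (3 + 10) >> 2 = 3   // == 10 / 3
// Keeping dividend * multiplier within 32 bits is what limits both operands to
// the [0, 1 << 16] range.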
// use the 32-bit-range version by default
using magic_division = magic_division32_bit_range;
struct mdiv
{
// 1 dword -> 3 dword storage
uint32_t divisor;
uint32_t multiplier;
uint32_t shift; // TODO: 8 bit is enough
// prefer constructing on the host
CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) : divisor(divisor_)
{
auto tmp = magic_division::calculate_magic_numbers(divisor_);
multiplier = tmp[number<0>{}];
shift = tmp[number<1>{}];
}
CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {}
CK_TILE_HOST_DEVICE void update(uint32_t divisor_)
{
divisor = divisor_;
auto tmp = magic_division::calculate_magic_numbers(divisor_);
multiplier = tmp[number<0>{}];
shift = tmp[number<1>{}];
}
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
{
return magic_division::do_magic_division(dividend_, multiplier, shift);
}
CK_TILE_HOST_DEVICE void
divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor);
}
CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; }
};
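// Usage sketch (illustrative; n and linear_id are assumed): precompute the magic
// numbers once (preferably on the host), then reuse them for many divisions by the
// same divisor,
//   mdiv m{n};
//   uint32_t q, r;
//   m.divmod(linear_id, q, r); // q = linear_id / n, r = linear_id % n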
struct mdiv2
{
// 1 dword -> 2 dword storage; the divisor itself is not stored, so the caller must supply it at runtime
uint32_t multiplier;
uint32_t shift; // TODO: 8 bit is enough
// prefer constructing on the host
CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_)
{
auto tmp = magic_division::calculate_magic_numbers(divisor_);
multiplier = tmp[number<0>{}];
shift = tmp[number<1>{}];
}
CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {}
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
{
return magic_division::do_magic_division(dividend_, multiplier, shift);
}
CK_TILE_HOST_DEVICE void
divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor_);
}
};
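// Usage sketch (illustrative; stride and offset are assumed): mdiv2 drops the
// stored divisor to save a dword, so the caller passes it back in for the
// remainder,
//   mdiv2 m{stride};
//   uint32_t q, r;
//   m.divmod(offset, stride, q, r); // q = offset / stride, r = offset % stride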
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include <stdint.h>
#include <tuple>
#include <type_traits>
namespace ck_tile {
// return 0 if data is not fp16 or fp32
template <typename T, uint32_t seed_>
struct prand_generator_t
{
CK_TILE_HOST_DEVICE uint32_t operator()(int, T, uint32_t = seed_) { return 0; }
};
// version for fp32
template <uint32_t seed_>
struct prand_generator_t<float, seed_>
{
CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
{
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits ^= x >> 16;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
drop_bits *= 0x7000149;
// NOTE: if id is 64-bit, only its lower 32 bits are used.
// As a result, very large ids can collide, effectively reusing the same id for
// multiple elements!
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
return rng;
}
};
// version for fp16
template <uint32_t seed_>
struct prand_generator_t<half_t, seed_>
{
CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
{
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
drop_bits *= 0x7000149;
// NOTE: if id is 64-bit, only its lower 32 bits are used.
// As a result, very large ids can collide, effectively reusing the same id for
// multiple elements!
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
return rng;
}
};
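// Usage sketch (illustrative; element_id and val are assumed): produce a
// deterministic per-element pseudo-random word that mixes the element's bits, its
// linear id and a seed; non-fp16/fp32 types always return 0,
//   prand_generator_t<float, 0xCAFEu> prand;
//   uint32_t r = prand(element_id, val); // same inputs always give the same output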
} // namespace ck_tile