Commit a4522ae3 authored by illsilin

sync from public repo

parents 1f127242 e0594d08
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -18,6 +18,8 @@
namespace ck_tile {
// Note: this tile window does not support single issue
// you need to use the tile_window_linear structure for that purpose
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
@@ -41,6 +43,7 @@ struct tile_window_with_static_distribution
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static_assert(NumCoord == 1);
// TODO: check WindowLengths and StaticTileDistribution are consistent
@@ -189,7 +192,8 @@ struct tile_window_with_static_distribution
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -222,10 +226,11 @@ struct tile_window_with_static_distribution
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const AdaptorTopIndex& idx_diff_adaptor_top) const
const ATopIndex& idx_diff_adaptor_top) const
{
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
@@ -279,20 +284,28 @@ struct tile_window_with_static_distribution
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}
CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
template <bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
load(dst_tensor, bool_constant<oob_conditional_check>{});
return dst_tensor;
}
template <typename DistributedTensor, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
@@ -308,11 +321,11 @@ struct tile_window_with_static_distribution
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
@@ -338,20 +351,23 @@ struct tile_window_with_static_distribution
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
return dst_tensor;
}
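// Illustrative sketch of the two load() flavors above (the window/tile names below are
// assumptions, not identifiers from this file):
//   auto tile = window.load();   // allocates a static_distributed_tensor and fills it
//   window.load(tile);           // fills an existing distributed tensor, reusing its registers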
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
template <typename DstTile,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
@@ -397,6 +413,7 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord,
0 /**/,
bool_constant<oob_conditional_check>{},
pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
@@ -409,23 +426,24 @@ struct tile_window_with_static_distribution
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
"scratch memory" ::);
#endif
}
// TODO: currently async load is only implemented in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
@@ -467,7 +485,7 @@ struct tile_window_with_static_distribution
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// TODO: use structure binding (to be captured later) if compiled in C++20
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
@@ -482,15 +500,16 @@ struct tile_window_with_static_distribution
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, pre_nop_);
smem, bottom_tensor_thread_coord, 0, pre_nop_);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -501,8 +520,81 @@ struct tile_window_with_static_distribution
});
}
template <bool oob_conditional_check = true>
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// TODO: an LDS offset is not good for an intrinsic-based implementation (the compiler can't
// figure out the dependency), hence avoid an offset-based solution. size_per_buf should be
// zero (how to check?)
constexpr index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
constexpr index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
size_per_buf;
constexpr index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
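// Worked example with an assumed packed LDS layout (illustration only): if the descriptor
// describes (NumIssue, NumWarps, LanesPerIssue) = (4, 4, 64) elements stored contiguously,
// then size_per_buf = 0, size_per_wave = 64, size_per_issue = 256, and a thread in warp 2
// starts at m0_init_value = 0 + 64 * 2 = 128 elements into the LDS window.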
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
smem += size_per_issue; // Note we manually increase the per-issue offset
}
});
});
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
@@ -515,7 +607,6 @@ struct tile_window_with_static_distribution
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
@@ -530,7 +621,7 @@ struct tile_window_with_static_distribution
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
@@ -548,15 +639,19 @@ struct tile_window_with_static_distribution
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -565,8 +660,9 @@ struct tile_window_with_static_distribution
});
}
CK_TILE_DEVICE void
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
template <index_t i_access_unsupport_ = -1>
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {}) const
{
using Traits = load_store_traits;
@@ -591,7 +687,7 @@ struct tile_window_with_static_distribution
// read from distributed tensor
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
@@ -606,15 +702,16 @@ struct tile_window_with_static_distribution
// write into bottom tensor
get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, vec_value);
bottom_tensor_thread_coord, 0, vec_value);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -623,8 +720,9 @@ struct tile_window_with_static_distribution
});
}
template <bool oob_conditional_check = true>
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
@@ -650,7 +748,7 @@ struct tile_window_with_static_distribution
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
@@ -666,15 +764,19 @@ struct tile_window_with_static_distribution
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -746,7 +848,8 @@ struct tile_window_with_static_distribution
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -798,6 +901,27 @@ make_tile_window(const TensorView_& tensor_view,
tensor_view, window_lengths, origin, tile_distribution};
}
// this version can't be called in a constexpr context
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord = 1>
CK_TILE_DEVICE auto
make_tile_window_raw(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
number<NumCoord> = {})
{
auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
NumCoord>{
tensor_view, window_lengths, origin, tile_distribution};
w.init_raw();
return w;
}
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
@@ -922,6 +1046,19 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
tile_distribution);
}
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto
make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution)
{
auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution);
w.init_raw();
return w;
}
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window(
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
#define WINDOW_DISPATCH_ISSUE() \
if constexpr(i_access < 0) \
{ \
static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); }); \
} \
else \
{ \
static_assert(i_access < NumAccess); \
issue(number<i_access>{}); \
}
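// Illustrative sketch of how WINDOW_DISPATCH_ISSUE() behaves at a call site (the calls below
// are assumptions, not code from this file):
//   win.load(number<-1>{}); // i_access < 0: statically loops and issues all NumAccess accesses
//   win.load(number<3>{});  // i_access >= 0: issues only access #3, the caller schedules the rest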
//
// This version of the tile window pre-caches offsets/flags as needed.
//
// LinearBottomDims_, e.g. seq<0, 1> for a 2d tensor: the last dim is the linear dim,
// so that dim can be indexed with an immediate offset, which can save registers.
// TODO: if using this struct, prefer load_raw()/store_raw(), which can control
// the immediate offset on the fly.
// Note: the space-filling curve is non-snaked here!
//
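// Illustrative example with assumed numbers (intuition only): if a thread issues 4 accesses
// along the last dim of a 2d tensor, LinearBottomDims_ = seq<0, 1> caches a single coordinate
// register plus 4 compile-time linear offsets (immediates), while seq<0, 0> would pre-cache
// 4 coordinate registers, one per non-linear access.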
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename LinearBottomDims_>
struct tile_window_linear
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
using LinearBottomDims = remove_cvref_t<LinearBottomDims_>;
static_assert(LinearBottomDims::size() == BottomTensorView::get_num_of_dimension());
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
// TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileDstr::is_static(), "wrong!");
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
"wrong! inconsistent # of dimensions");
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
using WindowAdaptorCoord =
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
struct traits
{
private:
// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
// bottom tensor top dimension vector lengths and strides
const auto [bottom_tensor_top_dim_vector_lengths,
bottom_tensor_top_dim_vector_strides] =
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
// window vector lengths/strides
const auto window_adaptor_bottom_dim_vector_lengths =
bottom_tensor_top_dim_vector_lengths;
const auto window_adaptor_bottom_dim_vector_strides =
bottom_tensor_top_dim_vector_strides;
// window adaptor [p0, p1, ..., y0, y1, ...]
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>
window_adaptor_vector_lengths{-1};
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>
window_adaptor_vector_strides{-1};
constexpr auto window_adaptor_bottom_dims =
WindowAdaptor::get_bottom_dimension_hidden_ids();
set_container_subset(window_adaptor_vector_lengths,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_lengths);
set_container_subset(window_adaptor_vector_strides,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_strides);
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
window_adaptor_vector_lengths, window_adaptor_vector_strides);
// [y0, y1, ...]
constexpr auto y_dims =
typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
NDimWindowAdaptorTop,
1>::type{};
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}
static constexpr auto get_vector_dim_y_scalar_per_vector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
get_window_adaptor_ys_safe_vector_length_strides();
index_t VectorDimY_ = 0;
index_t ScalarPerVector_ = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
{
ScalarPerVector_ = ys_vector_lengths[i];
VectorDimY_ = i;
}
}
return make_tuple(VectorDimY_, ScalarPerVector_);
}
public:
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
using vector_t = thread_buffer<DataType, ScalarPerVector>;
private:
static constexpr auto scalars_per_access_ = [] {
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr auto NDimY_ = NDimY;
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
}();
static constexpr auto get_space_filling_curve()
{
constexpr auto thread_tensor_lengths_ys =
to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
// FIXME: need logic to judge dim access order
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
return space_filling_curve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access_),
false /*!!! no snaked curve! */>{};
}
public:
using SFC_Ys = decltype(get_space_filling_curve());
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
private:
static constexpr auto get_num_non_linear_access()
{
constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
using ys_to_rhs_major =
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto non_linear = [&]() {
index_t cnt = 1;
static_for<0, NDimY, 1>{}([&](auto i_dim_y) {
constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
if constexpr(LinearBottomDims{}[target_h_dim] == 0)
{
cnt *= sfc_access_lens[i_dim_y];
}
});
return cnt;
}();
return non_linear;
}
// example:
// non_linear_access_map: sequence<0, 0, 0, 0, 1, 1, 1, 1> for 8 accesses, 2 registers
// used in total
// -> histogram : sequence<4, 4>
// -> prefix sum : sequence<0, 4, 8>
// non_linear_access_map: sequence<0, 1, 2, 3, 4, 5, 6, 7> for 8 accesses, 8 registers
// used in total, will pre-cache 8
// -> histogram : sequence<1, 1, 1, 1, 1, 1, 1, 1>
// -> prefix sum : sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>
// non_linear_access_map: sequence<0, 0, 1, 1, 2, 2, 3, 3> for 8 accesses, 4 registers
// used in total, will pre-cache 4
// -> histogram : sequence<2, 2, 2, 2>
// -> prefix sum : sequence<0, 2, 4, 6, 8>
static constexpr auto get_non_linear_access_map()
{
constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
using ys_to_rhs_major =
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto non_linear_map = [&]() {
array<index_t, NumAccess> m_{0};
index_t cumulative_len_ = 1;
index_t cumulative_non_linear_len_ = 1;
static_for<0, NDimY, 1>{}([&](auto i_y) {
constexpr auto i_dim_y = number<NDimY - i_y - 1>{}; // from right to left
constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
constexpr auto is_linear_dim = LinearBottomDims{}[target_h_dim];
array<index_t, NumAccess> current_m_{0};
constexpr auto current_len_ = sfc_access_lens[i_dim_y];
// copy cumulative length as current pattern
for(auto i_ = 0; i_ < cumulative_len_; i_++)
{
current_m_(i_) = m_[i_];
}
for(auto j_ = 0; j_ < current_len_; j_++)
{
auto j_offset_ = is_linear_dim ? 0 : j_ * cumulative_non_linear_len_;
for(auto i_ = 0; i_ < cumulative_len_; i_++)
{
m_(j_ * cumulative_len_ + i_) = current_m_[i_] + j_offset_;
}
}
cumulative_len_ *= current_len_;
if(!is_linear_dim)
cumulative_non_linear_len_ *= current_len_;
});
return m_;
}();
return TO_SEQUENCE(non_linear_map, NumAccess);
}
static constexpr auto get_non_linear_access_histogram()
{
constexpr auto m_ = get_non_linear_access_map();
// m_.foo();
constexpr auto r_ =
typename arithmetic_sequence_gen<0, get_num_non_linear_access() + 1, 1>::type{};
constexpr auto h_ = histogram_sorted_sequence(m_, r_);
return h_;
}
static constexpr auto get_non_linear_access_histogram_prefix_sum()
{
constexpr auto h_ = get_non_linear_access_histogram();
constexpr auto h_prefix_sum_ = prefix_sum_sequence(h_);
return h_prefix_sum_;
}
public:
static constexpr index_t NumAccess_NonLinear = get_num_non_linear_access();
using AccessMap_NonLinear = decltype(get_non_linear_access_map()); // sequence
using AccessHistogram_NonLinear = decltype(get_non_linear_access_histogram());
using AccessPrefixSum_NonLinear = decltype(get_non_linear_access_histogram_prefix_sum());
};
static constexpr index_t NumAccess = traits::NumAccess;
static constexpr index_t NumAccess_NonLinear = traits::NumAccess_NonLinear;
using AccessMap_NonLinear = typename traits::AccessMap_NonLinear;
using AccessHistogram_NonLinear = typename traits::AccessHistogram_NonLinear;
using AccessPrefixSum_NonLinear = typename traits::AccessPrefixSum_NonLinear;
CK_TILE_DEVICE constexpr tile_window_linear() = default;
CK_TILE_DEVICE constexpr tile_window_linear(const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
cached_coords_{},
cached_flags_{}
{
auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(make_tuple(get_warp_id(), get_lane_id()),
generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute the cached coordinates/flags for future load/store() calls (might allocate more registers)
using SFC_Ys = typename traits::SFC_Ys;
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
constexpr auto need_save_non_linear_coord =
bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};
if constexpr(need_save_non_linear_coord)
{
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
}
// TODO: need pad_tensor_view to tell which dims need a flag check
// the cached flag is independent of the non-linear coord,
// but it needs to be updated in move_tile with the proper dims
cached_flags_(i_access) = coordinate_has_valid_offset_assuming_top_index_is_valid(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp);
if constexpr(i_access != (NumAccess - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord_tmp,
bottom_tensor_thread_coord_tmp,
idx_diff_ps_ys);
}
});
}
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
{
return TileDstr::is_static();
}
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const ATopIndex& idx_diff_adaptor_top) const
{
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
window_adaptor_thread_coord,
idx_diff_adaptor_top,
idx_diff_adaptor_bottom);
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
bottom_tensor_thread_coord,
idx_diff_adaptor_bottom);
}
template <index_t i_access>
CK_TILE_DEVICE static constexpr auto get_bottom_linear_coordinate(number<i_access>)
{
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto idx_ys = SFC_Ys::get_index(number<i_access>{});
using ys_to_rhs_major =
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto modified_idx_ys = generate_tuple(
[&](auto i_dim_y) {
constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
if constexpr(LinearBottomDims{}[target_h_dim] == 0)
{
return number<0>{};
}
else
{
return number<idx_ys[i_dim_y]>{};
}
},
number<NDimY>{});
constexpr auto adaptor_ = TileDstr{}.get_ps_ys_to_xs_adaptor();
constexpr auto idx_ =
container_concat(make_tuple(number<0>{}, number<0>{}), modified_idx_ys);
return adaptor_.calculate_bottom_index(idx_);
}
template <index_t i_access>
CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
{
constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
// since this is a linear offset, we assume the bottom X tensor is always linear
constexpr index_t linear_offset = [&]() {
constexpr auto x_idx_ = linear_coord;
constexpr auto x_len_ = TileDstr{}.get_lengths();
static_assert(x_idx_.size() == x_len_.size());
constexpr index_t x_dims_ = x_idx_.size();
index_t cu_stride_ = 1;
index_t cu_offset_ = 0;
static_for<0, x_dims_, 1>{}([&](auto i_) {
auto r_i_ = number<x_dims_ - i_ - 1>{};
cu_offset_ += x_idx_[r_i_] * cu_stride_;
cu_stride_ *= x_len_[r_i_];
});
return cu_offset_;
}();
return linear_offset;
}
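// Worked example with assumed numbers (illustration only): for X lengths (4, 64) and a
// linear coordinate (1, 16), the row-major fold above yields
//   linear_offset = 16 * 1 + 1 * 64 = 80,
// which becomes the compile-time immediate offset handed to the buffer/lds access.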
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
#if 1
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
};
WINDOW_DISPATCH_ISSUE();
return dst_tensor;
}
template <typename DstTile,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
number<i_access> = {}, // negative means loop over all num_access
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
static constexpr index_t YElementSize =
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
static_assert(YElementSize % traits::ScalarPerVector == 0);
using vectorized_tbuf = array<vector_t, YElementSize / traits::ScalarPerVector>;
constexpr auto tile_dstr = TileDstr{};
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access_ == 0 &&
BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
auto bottom_tensor_flag = cached_flags_[IAccess];
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % traits::ScalarPerVector == 0);
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / traits::ScalarPerVector>(),
bottom_tensor_thread_coord,
linear_offset /**/,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{},
pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
#endif
};
WINDOW_DISPATCH_ISSUE();
}
// TODO: currently async load is only implemented in inline asm
template <typename LdsTileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// currently we only support the case where every dim is non-linear;
// it's actually not performant to have a linear (e.g. fast-changing) dim here
static_assert(NumAccess_NonLinear == NumAccess);
static_assert(BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global);
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(m0_init_value); // This should be wave independent
using vector_t = typename traits::vector_t;
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access_ == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess]; // get this flag anyway
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
// move thread coordinate
if constexpr(i_access_ != (NumAccess - 1))
{
m0_inc_with_memory(size_per_issue);
}
};
WINDOW_DISPATCH_ISSUE();
}
template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// currently we only support the case where every dim is non-linear;
// it's actually not performant to have a linear (e.g. fast-changing) dim here
static_assert(NumAccess_NonLinear == NumAccess);
static_assert(BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global);
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// TODO: an LDS offset is not good for an intrinsic-based implementation (the compiler can't
// figure out the dependency), hence avoid an offset-based solution. size_per_buf should be
// zero (how to check?)
constexpr index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
constexpr index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
size_per_buf;
constexpr index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using vector_t = typename traits::vector_t;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
0,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(i_access_ != (NumAccess - 1))
{
smem += size_per_issue; // Note we manually increase the per-issue offset
}
};
WINDOW_DISPATCH_ISSUE();
}
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
auto bottom_tensor_flag = cached_flags_[IAccess];
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
vec_value,
bool_constant<oob_conditional_check>{});
};
WINDOW_DISPATCH_ISSUE();
}
template <index_t i_access = -1>
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
static constexpr bool oob_conditional_check = true;
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
auto bottom_tensor_flag = cached_flags_[IAccess];
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
};
WINDOW_DISPATCH_ISSUE();
}
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
auto bottom_tensor_flag = cached_flags_[IAccess];
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
vec_value,
bool_constant<oob_conditional_check>{});
};
WINDOW_DISPATCH_ISSUE();
}
// move the thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset]
// also move the window origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
{
window_origin_ += step;
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
constexpr auto need_update_non_linear_coord =
bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};
if constexpr(need_update_non_linear_coord)
{
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
cached_coords_(non_linear_id),
step);
}
// move the current coord with linear_coords
auto tmp_coords = cached_coords_[non_linear_id];
constexpr auto linear_coord = get_bottom_linear_coordinate(IAccess);
move_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord);
cached_flags_(IAccess) = coordinate_has_valid_offset_assuming_top_index_is_valid(
bottom_tensor_view_.get_tensor_descriptor(), tmp_coords);
});
}
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
TileDstr{}.get_ps_ys_to_xs_adaptor(),
container_concat(make_tuple(get_warp_id(), get_lane_id()),
generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute the cached coordinates for future load/store() calls (might allocate more registers)
using SFC_Ys = typename traits::SFC_Ys;
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
constexpr auto need_save_non_linear_coord =
bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};
if constexpr(need_save_non_linear_coord)
{
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
}
if constexpr(i_access != (NumAccess - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord_tmp,
bottom_tensor_thread_coord_tmp,
idx_diff_ps_ys);
}
});
}
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
//
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
// this contains:
array<BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
array<bool, traits::NumAccess> cached_flags_;
};
#undef WINDOW_DISPATCH_ISSUE
namespace impl {
template <address_space_enum, index_t len_>
struct default_linear_bottom_dims_impl
{
using type = typename uniform_sequence_gen<len_, 0>::type;
};
template <index_t len_>
struct default_linear_bottom_dims_impl<address_space_enum::global, len_>
{
// global default to seq<0,0,....1>
using type = typename sequence_merge<typename uniform_sequence_gen<len_ - 1, 0>::type,
sequence<1>>::type;
};
template <index_t len_>
struct default_linear_bottom_dims_impl<address_space_enum::lds, len_>
{
// lds default to seq<1,1.....1>
using type = typename uniform_sequence_gen<len_, 1>::type;
};
} // namespace impl
template <typename TensorView_>
using default_linear_bottom_dims =
typename impl::default_linear_bottom_dims_impl<TensorView_::buffer_view::get_address_space(),
TensorView_::get_num_of_dimension()>::type;
// using this API will create a tile_window_linear
// this structure has the chance to use immediate values and save registers
// pass in LinearBottomDims_ properly to control which dim is linear,
// so that a constexpr offset (linear_offset) is generated for that dim
// (and finally passed as the immediate offset of the buffer/lds instruction)
//
// Note: there is no internal check for which dim is OK to use a linear offset;
// users must make sure of this by themselves
//
// e.g.
// 2d global matrix: set LinearBottomDims_=seq<0, 1>; the last dim will generate an
// immediate offset if each thread has multiple issues along the last dim
//
// 2d LDS buffer: set LinearBottomDims_=seq<1, 1>; then only one vgpr is used as the offset,
// everything else just uses immediate offsets.
//
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
LinearBottomDims_ = {})
{
static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
return tile_window_linear<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<LinearBottomDims_>>{
tensor_view, window_lengths, origin, tile_distribution};
}
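// Illustrative usage sketch (gmem_view, lds_view, lens, origin and dstr are assumed to be
// defined elsewhere; they are not identifiers from this file):
//   auto gwin = make_tile_window_linear(gmem_view, lens, origin, dstr);     // default: seq<0,...,1>
//   auto lwin = make_tile_window_linear(lds_view, lens, origin, dstr,
//                                       sequence<1, 1>{});                  // all dims linear (LDS)
//   auto tile = gwin.load();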
template <
typename TileWindow_,
typename StaticTileDistribution_,
typename LinearBottomDims_ = default_linear_bottom_dims<typename TileWindow_::BottomTensorView>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear(const TileWindow_& tile_window,
const StaticTileDistribution_& tile_distribution,
LinearBottomDims_ = {})
{
return make_tile_window_linear(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution,
LinearBottomDims_{});
}
// this version must not be called under a constexpr context
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
CK_TILE_DEVICE auto
make_tile_window_linear_raw(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
LinearBottomDims_ = {})
{
static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
auto w = tile_window_linear<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<LinearBottomDims_>>{
tensor_view, window_lengths, origin, tile_distribution};
w.init_raw();
return w;
}
template <
typename TileWindow_,
typename StaticTileDistribution_,
typename LinearBottomDims_ = default_linear_bottom_dims<typename TileWindow_::BottomTensorView>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear_raw(const TileWindow_& tile_window,
const StaticTileDistribution_& tile_distribution,
LinearBottomDims_ = {})
{
return make_tile_window_linear_raw(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution,
LinearBottomDims_{});
}
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename LinearBottomDims_>
CK_TILE_DEVICE void move_tile_window(
tile_window_linear<TensorView_, WindowLengths_, StaticTileDistribution_, LinearBottomDims_>&
window,
const typename tile_window_linear<TensorView_,
WindowLengths_,
StaticTileDistribution_,
LinearBottomDims_>::BottomTensorIndex& step)
{
window.move(step);
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// This file should not be included inside tuple.hpp!
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>
#include <utility>
namespace ck_tile {
namespace detail {
// RemainLengths: sequence<...>
// Orders: sequence<...>
template <class RemainLengths, class RamainUnpacks, class Orders>
struct static_uford_impl
{
CK_TILE_HOST_DEVICE constexpr static_uford_impl()
{
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
static_assert(RamainUnpacks::size() > 0, "wrong! should not get here");
}
template <class F, class CurrentUnpackIds>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const
{
constexpr index_t pack_len = RamainUnpacks::front();
static_for<0, RemainLengths::front(), pack_len>{}([=](auto I) {
constexpr auto new_pack = generate_tuple(
[&](auto idx_) {
constexpr auto i_new_pack = number<I + idx_ % pack_len>{};
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
},
number<CurrentUnpackIds::size() * pack_len>{});
static_uford_impl<decltype(RemainLengths::pop_front()),
decltype(RamainUnpacks::pop_front()),
Orders>{}(f, new_pack);
});
}
};
template <class Orders>
struct static_uford_impl<sequence<>, sequence<>, Orders>
{
template <class F, class PackedId>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const
{
constexpr auto origin_packs = transform_tuples(
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
unpack(f, origin_packs);
}
};
template <class RemainLengths, class RamainUnpacks, class Orders>
struct static_uford_one_shot_impl
{
template <class F, class CurrentUnpackIds, index_t current_acc>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
{
constexpr auto r_lens_stride =
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies{}, number<1>{});
constexpr auto r_upks_stride =
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies{}, number<1>{});
constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
constexpr index_t pack_len = RamainUnpacks::front();
constexpr index_t current_idx = (current_acc / current_stride) * pack_len;
constexpr auto new_pack = generate_tuple(
[&](auto idx_) {
constexpr auto i_new_pack = number<current_idx + idx_ % pack_len>{};
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
},
number<CurrentUnpackIds::size() * pack_len>{});
static_uford_one_shot_impl<decltype(RemainLengths::pop_front()),
decltype(RamainUnpacks::pop_front()),
Orders>{}(f, new_pack, number<current_acc % current_stride>{});
}
};
template <class Orders>
struct static_uford_one_shot_impl<sequence<>, sequence<>, Orders>
{
template <class F, class PackedId, index_t current_acc>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number<current_acc>) const
{
constexpr auto origin_packs = transform_tuples(
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
unpack(f, origin_packs);
}
};
} // namespace detail
// TODO: we may unify static_ford/static_uford in the future
//
// loop over an n-d space (sequence) with packs
// the function passed in must take the same number of arguments as there are packs
//
// e.g.
// Lengths=seq<2, 3, 4>, Unpacks=seq<1, 1, 2>
// static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){}); // requires 2 args (packs)
//
// loop #0, i_0=seq<0, 0, 0>, i_1=<0, 0, 1>
// loop #1, i_0=seq<0, 0, 2>, i_1=<0, 0, 3>
// loop #2, i_0=seq<0, 1, 0>, i_1=<0, 1, 1>
// loop #3, i_0=seq<0, 1, 2>, i_1=<0, 1, 3>
// loop #4, i_0=seq<0, 2, 0>, i_1=<0, 2, 1>
// loop #5, i_0=seq<0, 2, 2>, i_1=<0, 2, 3>
// loop #6, i_0=seq<1, 0, 0>, i_1=<1, 0, 1>
// ...
template <class Lengths,
class Unpacks = typename uniform_sequence_gen<Lengths::size(), 1>::type,
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_uford
{
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies{}, number<1>{});
CK_TILE_HOST_DEVICE constexpr static_uford()
{
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
static_assert(Lengths::size() == Unpacks::size(), "wrong! inconsistent size");
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
static_for<0, Lengths::size(), 1>{}(
[&](auto i) { static_assert(Lengths{}.at(i) % Unpacks{}.at(i) == 0); });
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
{
using L_ = decltype(Lengths{} / Unpacks{});
return reduce_on_sequence(L_{}, multiplies{}, number<1>{});
}
// F signature: F(sequence<...> multi_id...)
// multi_id is the unordered multi-index
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
detail::static_uford_impl<decltype(ordered_lengths), decltype(ordered_unpacks), Orders>{}(
f, make_tuple(sequence<>{}));
}
// this overload is convenient for issuing accesses one at a time
template <class F, index_t i_access>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, number<i_access>) const
{
static_assert(i_access < get_num_of_access());
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
detail::static_uford_one_shot_impl<decltype(ordered_lengths),
decltype(ordered_unpacks),
Orders>{}(
f, make_tuple(sequence<>{}), number<i_access>{});
}
};
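// Hedged usage sketch (illustrative only, not part of the library): walk a 2x4
// space handing two packed indices to the callable per invocation, i.e.
// get_num_of_access() == (2/1) * (4/2) == 4 calls in total.
//
//   static_uford<sequence<2, 4>, sequence<1, 2>>{}([&](auto i_0, auto i_1) {
//       // i_0 and i_1 are sequence<...> multi-indices differing only in the last dim
//   });
//
// The number<i_access> overload issues exactly one of those calls, which is handy
// when interleaving each access with other per-access work.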
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
namespace ck_tile {
namespace literals {
// [P0330] Literal Suffix for (signed) size_t (C++23)
// ref: https://wg21.link/p0330r8
inline constexpr std::size_t operator""_uz(unsigned long long size)
{
return static_cast<std::size_t>(size);
}
inline constexpr std::size_t operator""_zu(unsigned long long size)
{
return static_cast<std::size_t>(size);
}
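// Hedged usage sketch: with `using namespace ck_tile::literals;` in scope, either
// suffix turns an integer literal into std::size_t without a verbose cast, e.g.
//   auto bytes = 4096_uz;   // std::size_t, usable in constant expressions
//   auto count = 8_zu;      // same value, alternative spelling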
} // namespace literals
} // namespace ck_tile
......@@ -59,8 +59,16 @@ struct magic_division32_bit_range
CK_TILE_DEVICE static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
if(__builtin_is_constant_evaluated())
{
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
return (tmp + dividend) >> shift;
}
else
{
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
}
}
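// Hedged sketch of the identity being used (not additional API): for a divisor d
// in the supported range, the host side precomputes a (multiplier, shift) pair
// such that
//   dividend / d == (__umulhi(dividend, multiplier) + dividend) >> shift
// for every valid dividend, turning the division into a multiply-high, an add and
// a shift; the constant-evaluated branch above emulates __umulhi with a 64-bit
// multiply so the same routine also works in constexpr contexts.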
CK_TILE_HOST static constexpr uint32_t
......@@ -77,9 +85,18 @@ struct magic_division32_bit_range
CK_TILE_DEVICE static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
if(__builtin_is_constant_evaluated())
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
return (tmp + dividend_u32) >> shift;
}
else
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
}
CK_TILE_HOST static constexpr int32_t
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
namespace ReduceOp {
// y = ReduceOp(y, x);
struct Add
{
template <typename T>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return type_convert<T>(0.0f);
};
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return y + x;
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
{
float y_ = type_convert<float>(y);
float x_ = type_convert<float>(x);
return type_convert<T>(y_ + x_);
}
};
struct SquareAdd
{
template <typename T>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return type_convert<T>(0.0f);
};
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return y + (x * x);
}
};
struct Max
{
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return numeric<T>::min();
};
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return max(y, x);
}
};
struct AbsMax
{
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return numeric<T>::min();
};
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return max(y, abs(x));
}
};
} // namespace ReduceOp
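// Hedged usage sketch (illustrative only): fold a sequence of values with one of
// the functors, starting from its identity value, which is the same pattern the
// host-side reference_reduce uses.
//
//   float acc = ReduceOp::Add::GetIdentityValue<float>();
//   for(float v : row)
//       acc = ReduceOp::Add{}(acc, v);   // running sum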
} // namespace ck_tile
......@@ -19,10 +19,15 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
......@@ -10,6 +10,7 @@
#include <random>
#include <type_traits>
#include <utility>
#include <unordered_set>
#include "ck_tile/core.hpp"
......@@ -41,6 +42,73 @@ struct FillUniformDistribution
}
};
namespace impl {
// clang-format off
template<index_t bytes> struct RawIntegerType_ {};
template<> struct RawIntegerType_<1> { using type = uint8_t;};
template<> struct RawIntegerType_<2> { using type = uint16_t;};
template<> struct RawIntegerType_<4> { using type = uint32_t;};
template<> struct RawIntegerType_<8> { using type = uint64_t;};
// clang-format on
template <typename T>
using RawIntegerType = typename RawIntegerType_<sizeof(T)>::type;
} // namespace impl
// Note: this struct is stateful (non-const): it remembers every value generated so far
// in order to keep the random fill unique
template <typename T>
struct FillUniformDistribution_Unique
{
float a_{-5.f};
float b_{5.f};
std::optional<uint32_t> seed_{11939};
std::mt19937 gen_{};
std::unordered_set<impl::RawIntegerType<T>> set_{};
FillUniformDistribution_Unique(float a = -5.f,
float b = 5.f,
std::optional<uint32_t> seed = {11939})
: a_(a),
b_(b),
seed_(seed),
gen_{seed_.has_value() ? *seed_ : std::random_device{}()},
set_{}
{
}
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last)
{
std::mt19937& gen = gen_;
std::uniform_real_distribution<float> dis(a_, b_);
auto& set = set_;
std::generate(first, last, [&dis, &gen, &set]() {
T v = static_cast<T>(0);
do
{
v = ck_tile::type_convert<T>(dis(gen));
} while(set.count(bit_cast<impl::RawIntegerType<T>>(v)) == 1);
set.insert(bit_cast<impl::RawIntegerType<T>>(v));
return v;
});
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range)
-> std::void_t<decltype(std::declval<FillUniformDistribution_Unique&>()(
std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
void clear() { set_.clear(); }
};
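// Hedged usage sketch (illustrative only): fill a host buffer so that no two
// elements share the same bit pattern; call clear() before reusing the functor on
// another buffer.
//
//   std::vector<ck_tile::half_t> buf(1024);
//   FillUniformDistribution_Unique<ck_tile::half_t> fill{-1.f, 1.f};
//   fill(buf);       // range overload; equivalent to fill(buf.begin(), buf.end())
//   fill.clear();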
template <typename T>
struct FillNormalDistribution
{
......
......@@ -11,6 +11,7 @@
#include <thread>
#include <utility>
#include <vector>
#include <functional>
#include "ck_tile/core.hpp"
#include "ck_tile/host/ranges.hpp"
......@@ -545,6 +546,28 @@ struct HostTensor
typename Data::size_type size() const { return mData.size(); }
// return a slice of this tensor
// for simplicity we just copy the data and return a new tensor
auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end) const
{
assert(s_begin.size() == s_end.size());
assert(s_begin.size() == get_num_of_dimension());
std::vector<size_t> s_len(s_begin.size());
std::transform(
s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus<size_t>{});
HostTensor<T> sliced_tensor(s_len);
sliced_tensor.ForEach([&](auto& self, auto idx) {
std::vector<size_t> src_idx(idx.size());
std::transform(
idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus<size_t>{});
self(idx) = operator()(src_idx);
});
return sliced_tensor;
}
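// Hedged usage sketch: for a 2-D tensor t with lengths {M, N},
//   auto rows = t.slice({m0, 0}, {m0 + BM, N});
// copies the row block [m0, m0 + BM) x [0, N) into a new, densely packed tensor.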
template <typename U = T>
auto AsSpan() const
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename ADataType, typename BDataType, typename ComputeDataType, typename ElementOp>
CK_TILE_HOST void reference_unary_elementwise(const HostTensor<ADataType>& a,
HostTensor<BDataType>& b,
ElementOp element_op)
{
// TODO: implement a GPU version of this reference function
auto f = [&](auto i) {
auto v_a = type_convert<ComputeDataType>(a.mData[i]);
auto v_b = element_op(v_a);
b.mData[i] = ck_tile::type_convert<BDataType>(v_b);
};
make_ParallelTensorFunctor(f, b.get_element_space_size())(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ComputeDataType,
typename ElementOp>
CK_TILE_HOST void reference_binary_elementwise(const HostTensor<ADataType>& a,
const HostTensor<BDataType>& b,
HostTensor<CDataType>& c,
ElementOp element_op)
{
// TODO: implement a GPU version of this reference function
auto f = [&](auto i) {
auto v_a = type_convert<ComputeDataType>(a.mData[i]);
auto v_b = type_convert<ComputeDataType>(b.mData[i]);
auto v_c = element_op(v_a, v_b);
c.mData[i] = ck_tile::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(f, c.get_element_space_size())(std::thread::hardware_concurrency());
}
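// Hedged usage sketch (illustrative only, M/N are placeholder sizes): elementwise
// add of two fp16 tensors with fp32 compute type.
//
//   HostTensor<ck_tile::half_t> a({M, N}), b({M, N}), c({M, N});
//   reference_binary_elementwise<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, float>(
//       a, b, c, [](float x, float y) { return x + y; });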
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include <thread>
namespace ck_tile {
......@@ -14,55 +15,36 @@ template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_n_k,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_n_k.mDesc.get_lengths()[0]
: b_n_k.mDesc.get_lengths()[1];
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[1]
: a_m_k.mDesc.get_lengths()[0];
const int M = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[0]
: a_m_k.mDesc.get_lengths()[1];
auto f = [&](auto m) {
for(int n = 0; n < N; ++n)
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0;
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_element_op(a_m_k(m, k))
: a_element_op(a_m_k(k, m));
BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_element_op(b_n_k(n, k))
: b_element_op(b_n_k(k, n));
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b);
}
CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? c_m_n(m, n)
: c_m_n(n, m);
c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
ADataType v_a = a_element_op(a_m_k(m, k));
BDataType v_b = b_element_op(b_k_n(k, n));
v_acc +=
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
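// Hedged usage sketch (illustrative only, M/N/K are placeholder sizes):
// C[M, N] = A[M, K] * B[K, N] with fp16 inputs, fp32 accumulation, fp16 output.
//
//   HostTensor<ck_tile::half_t> a_m_k({M, K});
//   HostTensor<ck_tile::half_t> b_k_n({K, N});
//   HostTensor<ck_tile::half_t> c_m_n({M, N});
//   reference_gemm<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t>(a_m_k, b_k_n, c_m_n);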
template <typename ADataType,
......
......@@ -8,20 +8,44 @@
namespace ck_tile {
// Note: for simplicity, each functor only cares about a single row (M index)
struct reference_layernorm2d_default_epilogue
{
template <typename OutDataType, typename AccDataType>
void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
{
const int N = acc.mDesc.get_lengths()[1];
for(int n = 0; n < N; ++n)
{
o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
}
}
template <typename OutDataType, typename AccDataType>
auto operator()(int m, const HostTensor<AccDataType>& acc)
{
HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
operator()(m, o, acc);
return o;
}
};
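// Hedged sketch of the epilogue contract (names illustrative): the reference
// kernel hands each row index m together with the ComputeDataType accumulator
// tile to the functor, which is responsible for writing the final output; a fused
// epilogue (e.g. row-wise quantization) would expose the same signature:
//
//   struct my_fused_epilogue {
//       template <typename OutDataType, typename AccDataType>
//       void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc);
//   };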
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
typename YDataType,
typename MeanDataType,
typename InvStdDataType>
typename InvStdDataType,
typename Epilogue = reference_layernorm2d_default_epilogue>
void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
const HostTensor<GammaDataType>& gamma_n,
const HostTensor<BetaDataType>& beta_n,
HostTensor<YDataType>& y_m_n,
HostTensor<MeanDataType>& mean_m,
HostTensor<InvStdDataType>& invStd_m,
ComputeDataType epsilon)
ComputeDataType epsilon,
Epilogue epilogue_functor = {})
{
auto layernorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
......@@ -51,16 +75,19 @@ void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);
HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
for(int n = 0; n < N; ++n)
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
ComputeDataType beta = ck_tile::type_convert<ComputeDataType>(beta_n(n));
auto y = (x - mean) * divisor;
y = y * gamma + beta;
auto a_ = (x - mean) * divisor;
a_ = a_ * gamma + beta;
y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
acc(m, n) = a_;
}
epilogue_functor(m, y_m_n, acc);
};
make_ParallelTensorFunctor(layernorm2d_fwd_func,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
#include <numeric>
#include <functional>
namespace ck_tile {
/*
performs permute + contiguous, similar to tensor.permute(dims).contiguous() in PyTorch
*/
template <typename DataType>
CK_TILE_HOST void
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
{
const auto x_len = x.mDesc.get_lengths();
const auto y_len = y.mDesc.get_lengths();
assert(x_len.size() == y_len.size());
index_t rank = x_len.size();
const auto x_elm = std::accumulate(x_len.begin(), x_len.end(), 1, std::multiplies<index_t>());
const auto y_elm = std::accumulate(y_len.begin(), y_len.end(), 1, std::multiplies<index_t>());
assert(x_elm == y_elm);
(void)y_elm;
auto f = [&](auto i_element) {
std::vector<size_t> y_coord = [&]() {
std::vector<size_t> tmp(rank, 0);
size_t r = i_element;
for(index_t i = rank - 1; i >= 0; i--)
{
tmp[i] = r % y_len[i];
r = r / y_len[i];
}
return tmp;
}();
std::vector<size_t> x_coord = [&]() {
std::vector<size_t> tmp(rank, 0);
for(index_t i = 0; i < rank; i++)
{
tmp[dims[i]] = y_coord[i];
}
return tmp;
}();
// do permute
y(y_coord) = x(x_coord);
};
make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
}
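// Hedged usage sketch (illustrative only, N/C/H/W are placeholder sizes):
// NCHW -> NHWC relayout,
//   HostTensor<float> x({N, C, H, W});
//   HostTensor<float> y({N, H, W, C});
//   reference_permute(x, y, {0, 2, 3, 1});   // y(n, h, w, c) == x(n, c, h, w)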
} // namespace ck_tile
......@@ -9,24 +9,25 @@
namespace ck_tile {
template <typename ADataType, typename AccDataType, typename BDataType>
CK_TILE_HOST void reference_reduce(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m)
template <typename XDataType, typename ComputeDataType, typename YDataType, typename ReduceOp>
CK_TILE_HOST void
reference_reduce(const HostTensor<XDataType>& x_m_n, HostTensor<YDataType>& y_m, ReduceOp reduce_op)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.get_lengths()[1];
const int N = x_m_n.mDesc.get_lengths()[1];
AccDataType v_acc = 0;
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_m_n(m, n);
const ComputeDataType v_a = type_convert<ComputeDataType>(x_m_n(m, n));
v_acc += v_a;
v_acc = reduce_op(v_acc, v_a);
}
b_m(m) = ck_tile::type_convert<BDataType>(v_acc);
y_m(m) = ck_tile::type_convert<YDataType>(v_acc);
};
make_ParallelTensorFunctor(f, b_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
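// Hedged usage sketch (illustrative only, M/N are placeholder sizes): row-wise sum
// of a 2-D tensor with fp32 compute,
//   HostTensor<ck_tile::half_t> x({M, N});
//   HostTensor<float> y({M});
//   reference_reduce<ck_tile::half_t, float, float>(x, y, ReduceOp::Add{});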
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename XDataType,
typename GammaDataType,
typename ComputeDataType,
typename YDataType,
typename InvRmsDataType>
void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
const HostTensor<GammaDataType>& gamma_n,
HostTensor<YDataType>& y_m_n,
HostTensor<InvRmsDataType>& invRms_m,
ComputeDataType epsilon)
{
auto rmsnorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
ComputeDataType mean_square = 0;
ComputeDataType divisor = 0;
for(int n = 0; n < N; ++n)
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
mean_square += x * x;
}
mean_square = mean_square / N;
divisor = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(mean_square + epsilon);
if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);
for(int n = 0; n < N; ++n)
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
auto y = x * divisor * gamma;
y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
}
};
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
std::thread::hardware_concurrency());
}
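// Hedged usage sketch (illustrative only, M/N are placeholder sizes):
// y = x * gamma / sqrt(mean(x^2) + eps) row-wise, also returning 1/rms per row.
//
//   HostTensor<ck_tile::half_t> x({M, N}), gamma({N}), y({M, N});
//   HostTensor<float> inv_rms({M});
//   reference_rmsnorm2d_fwd<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, float>(
//       x, gamma, y, inv_rms, 1e-5f);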
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename XDataType, typename ScaleDataType, typename QXDataType>
CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>& x_m_n,
const HostTensor<ScaleDataType>& scale_m,
HostTensor<QXDataType>& qx_m_n)
{
auto f = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
for(int n = 0; n < N; ++n)
{
auto v_x = x_m_n(m, n);
// scale = amax / 127 for int8
auto v_scale = type_convert<XDataType>(scale_m(m));
auto v_qx = v_x / v_scale;
qx_m_n(m, n) = saturates<QXDataType>{}(v_qx);
}
};
make_ParallelTensorFunctor(f,
scale_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
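// Hedged usage sketch (illustrative only, M/N are placeholder sizes): int8 row-wise
// quantization given per-row scales (typically amax / 127),
//   HostTensor<float> x({M, N});
//   HostTensor<float> scale({M});
//   HostTensor<int8_t> qx({M, N});
//   reference_rowwise_quantization2d(x, scale, qx);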
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -9,43 +9,81 @@
namespace ck_tile {
template <typename ADataType, typename AccDataType, typename BDataType>
CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n,
HostTensor<BDataType>& b_m_n)
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST void
reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.get_lengths()[1];
index_t rank = x.get_num_of_dimension();
assert(rank == y.get_num_of_dimension());
assert(dim == -1 || dim < rank);
AccDataType v_max = ck_tile::numeric<ADataType>::Lowest();
index_t target_dim = dim == -1 ? (rank - 1) : dim;
index_t softmax_len = x.get_length(target_dim);
index_t n_parallel = x.get_element_size() / softmax_len;
auto x_len = x.get_lengths();
// max
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_m_n(m, n);
auto f = [&](auto i_element) {
std::vector<size_t> coord = [&]() {
std::vector<size_t> t_(rank, 0);
size_t r = i_element;
for(index_t i = rank - 1; i >= 0; i--)
{
if(i == target_dim)
continue;
t_[i] = r % x_len[i];
r = r / x_len[i];
}
return t_;
}();
ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();
v_max = v_max < v_a ? v_a : v_max;
// compute max
for(auto idx = 0; idx < softmax_len; idx++)
{
auto c_ = coord;
c_[target_dim] = idx;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
v_max = v_max < v_x ? v_x : v_max;
}
AccDataType v_exp_sum = 0;
ComputeType v_exp_sum = static_cast<ComputeType>(0);
// sum
for(int n = 0; n < N; ++n)
for(auto idx = 0; idx < softmax_len; idx++)
{
const ADataType v_a = a_m_n(m, n);
auto c_ = coord;
c_[target_dim] = idx;
v_exp_sum += ck_tile::exp(v_a - v_max);
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
v_exp_sum += ck_tile::exp(v_x - v_max);
}
// elementwise
for(int n = 0; n < N; ++n)
for(auto idx = 0; idx < softmax_len; idx++)
{
const ADataType v_a = a_m_n(m, n);
auto c_ = coord;
c_[target_dim] = idx;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
auto out = ck_tile::exp(v_x - v_max) / v_exp_sum;
b_m_n(m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum;
y(c_) = ck_tile::type_convert<OutputType>(out);
}
};
make_ParallelTensorFunctor(f,
b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST auto reference_softmax(const HostTensor<InputType>& x, index_t dim = -1)
{
HostTensor<OutputType> y(x.get_lengths(), x.get_strides());
reference_softmax<InputType, ComputeType, OutputType>(x, y, dim);
return y;
}
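// Hedged usage sketch (illustrative only, B/S are placeholder sizes): softmax over
// the last dimension with fp32 compute, using the value-returning overload,
//   HostTensor<ck_tile::half_t> x({B, S});
//   auto y = reference_softmax<ck_tile::half_t, float>(x);   // HostTensor<float>, same shape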
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
#include <numeric>
#include <functional>
#include <utility>
#include <algorithm>
namespace ck_tile {
/*
similar to torch.topk()
x (Tensor) – the input tensor.
k (int) – the k in “top-k”
dim (int, optional) – the dimension to sort along
largest (bool, optional) – largest or smallest elements
sorted (bool, optional) – elements in sorted order or not
output:
y_values
y_indices
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TopKImpl.h
*/
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
HostTensor<DataType>& y_values,
HostTensor<IndexType>& y_indices,
index_t k,
index_t dim = -1,
bool largest = true,
bool sorted = true)
{
// rank must be the same
index_t rank = x.get_num_of_dimension();
assert(rank == y_values.get_num_of_dimension());
assert(rank == y_indices.get_num_of_dimension());
assert(dim == -1 || dim < rank);
index_t topk_dim = dim == -1 ? (rank - 1) : dim;
index_t topk_src_len = x.get_length(topk_dim);
auto x_len = x.get_lengths();
assert(k <= topk_src_len);
assert(k == y_values.get_length(topk_dim) && k == y_indices.get_length(topk_dim));
index_t n_parallel = x.get_element_size() / topk_src_len;
// clang-format off
auto f = [&](auto i_element) {
std::vector<size_t> topk_coord = [&](){
std::vector<size_t> t_(rank, 0);
size_t r = i_element;
for(index_t i = rank - 1; i >= 0; i--) {
if(i == topk_dim) continue; // topk dim should be zero
t_[i] = r % x_len[i]; r = r / x_len[i];
}
return t_;
}();
using elem_t = std::pair<DataType, IndexType>;
std::vector<elem_t> q = [&](){
std::vector<elem_t> t_(topk_src_len);
for(index_t i = 0; i < topk_src_len; i++) {
auto c_ = topk_coord; c_[topk_dim] = i;
t_[i].first = x(c_); t_[i].second = i;
}
return t_;
}();
// run topk
if(largest) {
std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
if(sorted) {
std::sort(q.begin(), q.begin() + k - 1,
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
}
} else {
std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
if(sorted) {
std::sort(q.begin(), q.begin() + k - 1,
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
}
}
// write out
for(index_t i = 0; i < k; i++) {
auto c_ = topk_coord; c_[topk_dim] = i;
y_values(c_) = q[i].first; y_indices(c_) = q[i].second;
}
};
// clang-format on
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}
// TODO: with this overload the returned tensors are always dense (no custom strides)
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST auto reference_topk(const HostTensor<DataType>& x,
index_t k,
index_t dim = -1,
bool largest = true,
bool sorted = true)
{
auto lens = x.get_lengths();
index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
assert(target_dim < lens.size());
assert(k <= lens[target_dim]);
lens[target_dim] = k;
HostTensor<DataType> y_values(lens);
HostTensor<IndexType> y_indices(lens);
reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);
return ck_tile::make_tuple(y_values, y_indices);
}
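// Hedged usage sketch (illustrative only, M/N are placeholder sizes): top-2 along
// the last dimension,
//   HostTensor<float> x({M, N});
//   auto r = reference_topk(x, 2);   // ck_tile::tuple of (values, indices), each with lengths {M, 2}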
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
// host side args
// X = A + B, Y = Rmsnorm2d(X), QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
struct AddRmsnorm2dRdquantFwdHostArgs
{
const void* p_a; // [m ,n], input, fp16/bf16
const void* p_b; // [m ,n], input, fp16/bf16
const void* p_gamma; // [1, n], gamma, prec same as input
void* p_x; // [m, n], output, p_a + p_b, fp16/bf16
void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of the result of rmsnorm2d(x)
void* p_qy;     // [m, n], output, quantized result of rmsnorm2d(x), int8
float epsilon;
index_t m;
index_t n;
index_t stride; // row_stride
};
// TODO: Extract some type to wrapper class
template <typename Pipeline_>
struct AddRmsnorm2dRdquantFwd
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = typename Pipeline::Problem;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
static constexpr bool kSaveX = Problem::kSaveX;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kThreePass = Problem::kThreePass;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
struct Kargs
{
const void* p_a;
const void* p_b;
const void* p_gamma;
void* p_x;
void* p_yscale;
void* p_qy;
float epsilon;
index_t m;
index_t n;
index_t stride; // row_stride
};
using Hargs = AddRmsnorm2dRdquantFwdHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{hargs.p_a,
hargs.p_b,
hargs.p_gamma,
hargs.p_x,
hargs.p_yscale,
hargs.p_qy,
hargs.epsilon,
hargs.m,
hargs.n,
hargs.stride};
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
return dim3(integer_divide_ceil(hargs.m, Block_M));
}
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
// in bytes
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
CK_TILE_HOST static std::string GetName()
{
// clang-format off
using S_ = typename Problem::BlockShape;
auto suffix = [&] () {
std::string n;
if (kPadN) n += "_pn";
if (kSaveX) n += "_x";
if (kThreePass) n += "_2p";
return n; }();
#define _SS_ std::string
#define _TS_ std::to_string
return _SS_("add_rmsnorm2d_rdquant_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + suffix;
#undef _SS_
#undef _TS_
// clang-format on
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const auto iM = get_block_id() * Block_M;
const auto a_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const ADataType*>(kargs.p_a),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto b_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const BDataType*>(kargs.p_b),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma),
make_tuple(kargs.n),
make_tuple(1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}();
auto x_window = [&]() {
if constexpr(kSaveX)
{
const auto tmp2_ = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
return pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
}();
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}();
auto yscale_window = [&]() {
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<YScaleDataType*>(kargs.p_yscale),
make_tuple(kargs.m),
make_tuple(1),
number<1>{});
auto tmp2_ = pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}();
auto qy_window = [&]() {
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<QYDataType*>(kargs.p_qy),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
__shared__ char smem[GetSmemSize()];
Pipeline{}(a_window,
b_window,
gamma_window,
x_window,
yscale_window,
qy_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n,
smem);
}
};
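// Hedged dispatch sketch (host side, illustrative only): one workgroup handles a
// Block_M x n tile of rows, so for a concrete instantiation
// `using Kernel = AddRmsnorm2dRdquantFwd<SomePipeline>;` the launch shape is
//
//   Kernel::Kargs kargs = Kernel::MakeKargs(hargs);   // hargs: AddRmsnorm2dRdquantFwdHostArgs
//   dim3 grids          = Kernel::GridSize(hargs);    // ceil(m / Block_M) workgroups
//   auto blocks         = Kernel::BlockSize();        // Problem::BlockShape::BlockSize
//
// which are then passed, together with Kernel{} and kargs, to the usual ck_tile
// kernel-launch helper.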
} // namespace ck_tile