Unverified Commit 63bc96e3 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Pre-compute coordinates to speed up store_tile() for TileWindowWithStaticDistribution<> (#12)



* Extract store_tile() logics as method

* Extract load_tile() logics as method

* Rename type alias

* Extract common logics as traits

* Remove unnecessary access specifier

* Add ComputeMode for TileWindowWithStaticDistribution

* Put field check into Traits

* More definition of Traits types

* Use more clear static_assert() message

* Enable pre-compute coordinates in store_tile()

* Re-formate static assert

* Undo changes to the wrong method

* Enable pre-compute coords for store_tile()

* Remove static_vector usage

* Add method to move non-member coordinates

* Force using pre-computed coordinates in Store()

* Fix wrong access for SFC_Ys

* Change comment

* Allow users to hint # access per coord

* Add comment for noting remove data members later

* Unify FIXME comments

* Replace FIXME comments by TODO

* Let user specify HintNumCoords

* clean

* clean

* clean

* clean

* refactor load/store for window

* clean

* clean

* bug fix for window; clean

---------
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent 7337ec25
...@@ -238,16 +238,16 @@ struct GemmSoftmaxGemmImpl ...@@ -238,16 +238,16 @@ struct GemmSoftmaxGemmImpl
using SBlockTileType = decltype(tile_elementwise_in( using SBlockTileType = decltype(tile_elementwise_in(
type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{})); type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{}));
using PBlockTileType = decltype( using PBlockTileType = decltype(tile_elementwise_in(type_convert<PDataType, SaccDataType>,
tile_elementwise_in(type_convert<PDataType, SaccDataType>, SaccBlockTileType{})); SaccBlockTileType{}));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>( using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype( using OaccBlockTileType = decltype(gemm1(
gemm1(get_slice_tile( get_slice_tile(
PBlockTileType{}, Sequence<0, 0>{}, Sequence<kM0PerBlock, kK1PerBlock>{}), PBlockTileType{}, Sequence<0, 0>{}, Sequence<kM0PerBlock, kK1PerBlock>{}),
v_dram_window)); v_dram_window));
// init Oacc, M, L // init Oacc, M, L
auto o_acc = OaccBlockTileType{}; auto o_acc = OaccBlockTileType{};
...@@ -322,7 +322,7 @@ struct GemmSoftmaxGemmImpl ...@@ -322,7 +322,7 @@ struct GemmSoftmaxGemmImpl
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this use different equation from FA v2 paper, // FIXME: this use different equation from FA v2 paper,
// but produce correc result. // but produce correct result.
// Is the equation wrong? // Is the equation wrong?
o_acc(i_j_idx) *= tmp; o_acc(i_j_idx) *= tmp;
}); });
...@@ -336,6 +336,7 @@ struct GemmSoftmaxGemmImpl ...@@ -336,6 +336,7 @@ struct GemmSoftmaxGemmImpl
const auto p = const auto p =
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute); tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
// Oacc{j}
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
if constexpr(k1_loops > 1) if constexpr(k1_loops > 1)
...@@ -369,7 +370,7 @@ struct GemmSoftmaxGemmImpl ...@@ -369,7 +370,7 @@ struct GemmSoftmaxGemmImpl
} while(iN0 < N0); } while(iN0 < N0);
// O // Oacc
constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans();
sweep_tile_span(o_spans[I0], [&](auto idx0) { sweep_tile_span(o_spans[I0], [&](auto idx0) {
......
...@@ -48,18 +48,18 @@ struct SpaceFillingCurve ...@@ -48,18 +48,18 @@ struct SpaceFillingCurve
ScalarPerVector; ScalarPerVector;
} }
template <index_t AccessIdx1dBegin, index_t AccessIdx1dEnd> template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
static __device__ __host__ constexpr auto GetStepBetween(Number<AccessIdx1dBegin>, static __device__ __host__ constexpr auto GetStepBetween(Number<AccessIdx1dHead>,
Number<AccessIdx1dEnd>) Number<AccessIdx1dTail>)
{ {
static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative"); static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < GetNumOfAccess(),
static_assert(AccessIdx1dBegin < GetNumOfAccess(), "1D index should be larger than 0"); "1D index out of range");
static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative"); static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < GetNumOfAccess(),
static_assert(AccessIdx1dEnd < GetNumOfAccess(), "1D index should be larger than 0"); "1D index out of range");
constexpr auto idx_begin = GetIndex(Number<AccessIdx1dBegin>{}); constexpr auto idx_head = GetIndex(Number<AccessIdx1dHead>{});
constexpr auto idx_end = GetIndex(Number<AccessIdx1dEnd>{}); constexpr auto idx_tail = GetIndex(Number<AccessIdx1dTail>{});
return idx_end - idx_begin; return idx_tail - idx_head;
} }
template <index_t AccessIdx1d> template <index_t AccessIdx1d>
......
...@@ -16,306 +16,16 @@ ...@@ -16,306 +16,16 @@
namespace ck { namespace ck {
namespace tile_program { namespace tile_program {
// detail used by tile-programming APIs(), not supposed to be used directly
namespace detail {
// TODO: deprecate
// "Y dimension": Y dimensions inside TileWindowWithStaticDistribution
// input:
// y_slice_origin: starting slice origin of Y dimension
// y_slice_lengths: slice lengths of Y dimensionr
// output:
// A StaticBuffer holding slice of thread data, and data layout is hardcoded to be in the order of
// [Y0, Y1, Y2, ...]
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename YIndex, index_t NumCoord>
index_t... YSliceLengths> __device__ auto load_tile(TileWindowWithStaticDistribution<BottomTensorView_,
__device__ auto load_sliced_thread_data_from_tile_window( WindowLengths_,
TileWindowWithStaticDistribution<BottomTensorView_, WindowLengths_, TileDistribution_>& TileDistribution_,
tile_window, NumCoord>& tile_window)
const YIndex& ys_slice_origin,
Sequence<YSliceLengths...>)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using BottomTensorView = remove_cvref_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
constexpr auto tile_dstr = TileDstr{};
constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
static_assert(NDimY == YIndex::Size() && NDimY == sizeof...(YSliceLengths),
"wrong! inconsistent # of dimension");
static_assert(TileWindow::HasStaticTileDistribution(),
"wrong! assume static tile distribution");
constexpr auto y_slice_lengths = Sequence<YSliceLengths...>{};
constexpr index_t thread_element_size =
container_reduce(y_slice_lengths, math::multiplies{}, 1);
StaticBuffer<AddressSpaceEnum::Vgpr, DataType, thread_element_size, true> thread_buf;
constexpr auto tmp = [&y_slice_lengths]() {
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY = 0;
index_t ScalarPerVector = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
{
ScalarPerVector = math::gcd(ys_vector_lengths[i], y_slice_lengths[i]);
VectorDimY = i;
}
}
return make_tuple(VectorDimY, ScalarPerVector);
}();
constexpr index_t VectorDimY = tmp.template At<0>();
constexpr index_t ScalarPerVector = tmp.template At<1>();
// FIXME
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
using SFC_Ys =
SpaceFillingCurve<decltype(y_slice_lengths), DimAccessOrder, decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Ys::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// move to slice origin
const auto ps_ys_slice_origin = container_concat(Array<index_t, NDimP>{0}, ys_slice_origin);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(ps_ys_slice_origin);
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// read from bottom tensor
const vector_t vec_value =
tile_window.GetBottomTensorView().template GetVectorizedElements<vector_t>(
tile_window.GetBottomTensorThreadCoordinate());
const vector_type_t vec{vec_value};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// write into distributed tensor
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
thread_buf.template At<d>() = vec.template AsType<DataType>()[j];
});
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
// move thread coordinate back to origin
{
constexpr auto idx_diff_ys = SFC_Ys::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
// move back to origin
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(MultiIndex<NDimP + NDimY>{0} -
ps_ys_slice_origin);
return thread_buf;
}
} // namespace detail
template <typename BottomTensorView_, typename WindowLengths_, typename TileDistribution_>
__device__ auto
load_tile(TileWindowWithStaticDistribution<BottomTensorView_, WindowLengths_, TileDistribution_>&
tile_window)
{ {
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>; return tile_window.Load();
using BottomTensorView = remove_cvref_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
static_assert(is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileWindow::HasStaticTileDistribution(), "wrong!");
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.GetYs2DDescriptor().GetLengths());
constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
static_assert(TileWindow::HasStaticTileDistribution(),
"wrong! assume static tile distribution");
constexpr auto tmp = [&thread_tensor_lengths_ys]() {
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY = 0;
index_t ScalarPerVector = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
{
ScalarPerVector = ys_vector_lengths[i];
VectorDimY = i;
}
}
return make_tuple(VectorDimY, ScalarPerVector);
}();
constexpr index_t VectorDimY = tmp.template At<0>();
constexpr index_t ScalarPerVector = tmp.template At<1>();
// FIXME
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
using SFC_Ys = SpaceFillingCurve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Ys::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
#if 1 // debug
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// read from bottom tensor
const vector_t vec_value =
tile_window.GetBottomTensorView().template GetVectorizedElements<vector_t>(
tile_window.GetBottomTensorThreadCoordinate());
const vector_type_t vec{vec_value};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// write into distributed tensor
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
dst_tensor.GetThreadBuffer().template At<d>() = vec.template AsType<DataType>()[j];
});
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
// move thread coordinate back to origin
{
constexpr auto idx_diff_ys = SFC_Ys::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
#else
auto tile_window_tmp = tile_window;
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// read from bottom tensor
const vector_t vec_value =
tile_window.GetBottomTensorView().template GetVectorizedElements<vector_t>(
tile_window_tmp.GetBottomTensorThreadCoordinate());
const vector_type_t vec{vec_value};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// write into distributed tensor
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
dst_tensor.GetThreadBuffer().template At<d>() = vec.template AsType<DataType>()[j];
});
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window_tmp.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
#endif
return dst_tensor;
} }
} // namespace tile_program } // namespace tile_program
......
...@@ -17,115 +17,15 @@ namespace tile_program { ...@@ -17,115 +17,15 @@ namespace tile_program {
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord,
typename DataType_> typename DataType_>
__device__ void __device__ void store_tile(TileWindowWithStaticDistribution<BottomTensorView_,
store_tile(TileWindowWithStaticDistribution<BottomTensorView_, WindowLengths_, TileDistribution_>& WindowLengths_,
tile_window, TileDistribution_,
const StaticDistributedTensor<DataType_, TileDistribution_>& dstr_tensor) NumCoord>& tile_window,
const StaticDistributedTensor<DataType_, TileDistribution_>& dstr_tensor)
{ {
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>; tile_window.Store(dstr_tensor);
using BottomTensorView = remove_cvref_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
static_assert(TileWindow::HasStaticTileDistribution(), "wrong!");
constexpr auto tile_dstr = TileDstr{};
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.GetYs2DDescriptor().GetLengths());
constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
constexpr auto tmp = []() {
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY = 0;
index_t ScalarPerVector = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
{
ScalarPerVector = ys_vector_lengths[i];
VectorDimY = i;
}
}
return make_tuple(VectorDimY, ScalarPerVector);
}();
constexpr index_t VectorDimY = tmp.template At<0>();
constexpr index_t ScalarPerVector = tmp.template At<1>();
// FIXME:
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
using SFC_Ys = SpaceFillingCurve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Ys::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// read from distributed tensor
vector_type_t vec;
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
vec.template AsType<DataType>()(j) = dstr_tensor.GetThreadBuffer().template At<d>();
});
const vector_t vec_value = vec.template AsType<vector_t>().template At<0>();
// write into bottom tensor
tile_window.GetBottomTensorView().template SetVectorizedElements<vector_t>(
tile_window.GetBottomTensorThreadCoordinate(), vec_value);
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
// move thread coordinate back to origin
{
constexpr auto idx_diff_ys = SFC_Ys::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
} }
} // namespace tile_program } // namespace tile_program
......
...@@ -23,11 +23,8 @@ __device__ void ...@@ -23,11 +23,8 @@ __device__ void
store_tile(TileWindowWithStaticLengths<BottomTensorView_, WindowLengths_>& tile_window_tmp, store_tile(TileWindowWithStaticLengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const StaticDistributedTensor<DataType_, TileDistribution_>& dstr_tensor) const StaticDistributedTensor<DataType_, TileDistribution_>& dstr_tensor)
{ {
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>; using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using BottomTensorView = remove_cvref_t<BottomTensorView_>; using TileDstr = remove_cvref_t<TileDistribution_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!"); static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
...@@ -38,98 +35,7 @@ store_tile(TileWindowWithStaticLengths<BottomTensorView_, WindowLengths_>& tile_ ...@@ -38,98 +35,7 @@ store_tile(TileWindowWithStaticLengths<BottomTensorView_, WindowLengths_>& tile_
tile_window_tmp.GetWindowOrigin(), tile_window_tmp.GetWindowOrigin(),
tile_dstr); tile_dstr);
constexpr auto thread_tensor_lengths_ys = tile_window.Store(dstr_tensor);
to_sequence(tile_dstr.GetYs2DDescriptor().GetLengths());
constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
constexpr auto tmp = []() {
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY = 0;
index_t ScalarPerVector = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
{
ScalarPerVector = ys_vector_lengths[i];
VectorDimY = i;
}
}
return make_tuple(VectorDimY, ScalarPerVector);
}();
constexpr index_t VectorDimY = tmp.template At<0>();
constexpr index_t ScalarPerVector = tmp.template At<1>();
// FIXME:
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
using SFC_Ys = SpaceFillingCurve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Ys::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// read from distributed tensor
vector_type_t vec;
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
vec.template AsType<DataType>()(j) = dstr_tensor.GetThreadBuffer().template At<d>();
});
const vector_t vec_value = vec.template AsType<vector_t>().template At<0>();
// write into bottom tensor
tile_window.GetBottomTensorView().template SetVectorizedElements<vector_t>(
tile_window.GetBottomTensorThreadCoordinate(), vec_value);
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
// move thread coordinate back to origin
{
constexpr auto idx_diff_ys = SFC_Ys::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
} }
} // namespace tile_program } // namespace tile_program
......
...@@ -13,7 +13,10 @@ ...@@ -13,7 +13,10 @@
namespace ck { namespace ck {
namespace tile_program { namespace tile_program {
template <typename BottomTensorView_, typename WindowLengths_, typename StaticTileDistribution_> template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord>
struct TileWindowWithStaticDistribution struct TileWindowWithStaticDistribution
{ {
using BottomTensorView = remove_reference_t<BottomTensorView_>; using BottomTensorView = remove_reference_t<BottomTensorView_>;
...@@ -23,11 +26,17 @@ struct TileWindowWithStaticDistribution ...@@ -23,11 +26,17 @@ struct TileWindowWithStaticDistribution
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc; using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = typename BottomTensorView::DataType; using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::GetNumOfTopDimension(); static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::GetNumOfTopDimension();
static constexpr index_t NDimBottomTensor = BottomTensorDesc::GetNumOfDimension(); static constexpr index_t NDimBottomTensor = BottomTensorDesc::GetNumOfDimension();
static constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
static constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
// TODO: check WindowLengths and StaticTileDistribution are consistent // TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert(is_known_at_compile_time<WindowLengths>::value, static_assert(is_known_at_compile_time<WindowLengths>::value,
...@@ -46,6 +55,73 @@ struct TileWindowWithStaticDistribution ...@@ -46,6 +55,73 @@ struct TileWindowWithStaticDistribution
using BottomTensorCoord = using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
struct LoadStoreTraits
{
private:
static constexpr auto GetVectorDimYScalarPerVector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindowWithStaticDistribution::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY_ = 0;
index_t ScalarPerVector_ = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
{
ScalarPerVector_ = ys_vector_lengths[i];
VectorDimY_ = i;
}
}
return make_tuple(VectorDimY_, ScalarPerVector_);
}
public:
static constexpr index_t VectorDimY = GetVectorDimYScalarPerVector().template At<0>();
static constexpr index_t ScalarPerVector = GetVectorDimYScalarPerVector().template At<1>();
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
private:
static constexpr auto scalars_per_access_ = [] {
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr auto NDimY_ = NDimY;
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
}();
static constexpr auto GetSpaceFillingCurve()
{
constexpr auto tile_dstr = TileDstr{};
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.GetYs2DDescriptor().GetLengths());
// FIXME: need logic to judge dim access order
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
return SpaceFillingCurve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access_)>{};
}
public:
using SFC_Ys = decltype(GetSpaceFillingCurve());
static constexpr index_t NumAccess = SFC_Ys::GetNumOfAccess();
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
};
static constexpr index_t NumAccessPerCoord = LoadStoreTraits::NumAccess / NumCoord;
__device__ constexpr TileWindowWithStaticDistribution() = default; __device__ constexpr TileWindowWithStaticDistribution() = default;
__device__ constexpr TileWindowWithStaticDistribution( __device__ constexpr TileWindowWithStaticDistribution(
...@@ -56,60 +132,63 @@ struct TileWindowWithStaticDistribution ...@@ -56,60 +132,63 @@ struct TileWindowWithStaticDistribution
: bottom_tensor_view_{bottom_tensor_view}, : bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths}, window_lengths_{window_lengths},
window_origin_{window_origin}, window_origin_{window_origin},
bottom_tensor_thread_coord_{},
tile_dstr_{tile_distribution}, tile_dstr_{tile_distribution},
window_adaptor_thread_coord_{} pre_computed_coords_{}
{ {
#if 0 // debug #if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
// only support warp-tile and block-tile // only support warp-tile and block-tile
static_assert(TileDstr::NDimP == 1 or TileDstr::NDimP == 2, "wrong!"); static_assert(NDimP == 1 or NDimP == 2, "wrong!");
if constexpr(TileDstr::NDimP == 1) WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{ {
window_adaptor_thread_coord_ = make_tensor_adaptor_coordinate( window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.GetPsYs2XsAdaptor(), AdaptorTopIndex{get_lane_id(), 0}); tile_distribution.GetPsYs2XsAdaptor(), AdaptorTopIndex{get_lane_id(), 0});
} }
else if constexpr(TileDstr::NDimP == 2) else if constexpr(NDimP == 2)
{ {
window_adaptor_thread_coord_ = window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_distribution.GetPsYs2XsAdaptor(), make_tensor_adaptor_coordinate(tile_distribution.GetPsYs2XsAdaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
} }
#elif 0
// only support warp-tile and block-tile
static_assert(TileDstr::NDimP == 1 or TileDstr::NDimP == 2, "wrong!");
if constexpr(TileDstr::NDimP == 1)
{
window_adaptor_thread_coord_ = make_tensor_adaptor_coordinate(
tile_distribution.GetPsYs2XsAdaptor(),
container_concat(Array<index_t, 1>{get_lane_id()},
Array<index_t, TileDstr::NDimY>{0}));
}
else if constexpr(TileDstr::NDimP == 2)
{
window_adaptor_thread_coord_ = make_tensor_adaptor_coordinate(
tile_distribution.GetPsYs2XsAdaptor(),
container_concat(Array<index_t, 2>{get_warp_id(), get_lane_id()},
Array<index_t, TileDstr::NDimY>{0}));
}
#else #else
window_adaptor_thread_coord_ = make_tensor_adaptor_coordinate( // TODO: this use less register for FA, but more register for GEMM
// need investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.GetPsYs2XsAdaptor(), tile_distribution.GetPsYs2XsAdaptor(),
container_concat(detail::get_partition_index(tile_distribution), container_concat(detail::get_partition_index(tile_distribution),
Array<index_t, TileDstr::NDimY>{0})); Array<index_t, NDimY>{0}));
#endif #endif
BottomTensorIndex bottom_tensor_thread_origin_idx; BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin + window_adaptor_thread_coord_tmp.GetBottomIndex();
for(index_t i = 0; i < NDimBottomTensor; ++i) const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
{ bottom_tensor_view_.GetTensorDescriptor(), bottom_tensor_thread_origin_idx_tmp);
bottom_tensor_thread_origin_idx(i) =
window_origin[i] + window_adaptor_thread_coord_.GetBottomIndex()[i]; // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
} // future Load/Store() calls (might allocate more registers)
using Traits = LoadStoreTraits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
constexpr auto idx_diff_ys =
SFC_Ys::GetStepBetween(Number<0>{}, Number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
MoveWindowAdaptorAndBottomTensorThreadCoordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
bottom_tensor_thread_coord_ = make_tensor_coordinate( pre_computed_coords_(iCoord) =
bottom_tensor_view_.GetTensorDescriptor(), bottom_tensor_thread_origin_idx); make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
} }
__device__ static constexpr index_t GetNumOfDimension() { return NDimBottomTensor; } __device__ static constexpr index_t GetNumOfDimension() { return NDimBottomTensor; }
...@@ -124,42 +203,22 @@ struct TileWindowWithStaticDistribution ...@@ -124,42 +203,22 @@ struct TileWindowWithStaticDistribution
__device__ constexpr auto GetWindowOrigin() const { return window_origin_; } __device__ constexpr auto GetWindowOrigin() const { return window_origin_; }
__device__ constexpr auto GetBottomTensorThreadCoordinate() const
{
return bottom_tensor_thread_coord_;
}
// move thread's window adaptor coordiante
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
__device__ void MoveWindowAdaptorThreadCoordinate(const AdaptorTopIndex& idx_diff_adaptor)
{
move_tensor_adaptor_coordinate(
tile_dstr_.GetPsYs2XsAdaptor(), window_adaptor_thread_coord_, idx_diff_adaptor);
}
// move thread's botom tensor coordiante
// [x0', x1', ... ] ==> [offset]
__device__ void MoveBottomTensorThreadCoordinate(const BottomTensorIndex& idx_diff_tensor)
{
move_tensor_coordinate(bottom_tensor_view_.GetTensorDescriptor(),
bottom_tensor_thread_coord_,
idx_diff_tensor);
}
// move thread's window adaptor coordinate and bottom tensor coordinate // move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
__device__ void __device__ void MoveWindowAdaptorAndBottomTensorThreadCoordinate(
MoveWindowAdaptorAndBottomTensorThreadCoordinate(const AdaptorTopIndex& idx_diff_adaptor_top) WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const AdaptorTopIndex& idx_diff_adaptor_top) const
{ {
Array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom; Array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
move_tensor_adaptor_coordinate(tile_dstr_.GetPsYs2XsAdaptor(), move_tensor_adaptor_coordinate(tile_dstr_.GetPsYs2XsAdaptor(),
window_adaptor_thread_coord_, window_adaptor_thread_coord,
idx_diff_adaptor_top, idx_diff_adaptor_top,
idx_diff_adaptor_bottom); idx_diff_adaptor_bottom);
move_tensor_coordinate(bottom_tensor_view_.GetTensorDescriptor(), move_tensor_coordinate(bottom_tensor_view_.GetTensorDescriptor(),
bottom_tensor_thread_coord_, bottom_tensor_thread_coord,
idx_diff_adaptor_bottom); idx_diff_adaptor_bottom);
} }
...@@ -200,6 +259,142 @@ struct TileWindowWithStaticDistribution ...@@ -200,6 +259,142 @@ struct TileWindowWithStaticDistribution
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
} }
    // Load the window's data from the bottom tensor into a freshly created
    // thread-local static distributed tensor, and return that tensor.
    //
    // Uses the NumCoord pre-computed (window-adaptor, bottom-tensor) coordinate
    // bundles in pre_computed_coords_ so only intra-bundle coordinate moves are
    // computed here; each bundle serves NumAccessPerCoord vectorized reads.
    // Access order over the Y dims comes from Traits::SFC_Ys (presumably a
    // space-filling-curve ordering, per the name — TODO confirm).
    __device__ auto Load() const
    {
        using Traits = LoadStoreTraits;

        using vector_type_t = typename Traits::vector_type_t;
        using vector_t = typename vector_type_t::type;
        using SFC_Ys = typename Traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // destination: thread-local distributed tensor with this window's distribution
        auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            // local, mutable copies of the pre-computed coordinate bundle
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                // global access index across all coordinate bundles
                constexpr auto iAccess = Number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);

                // read from bottom tensor (one vectorized load per access)
                const vector_t vec_value =
                    GetBottomTensorView().template GetVectorizedElements<vector_t>(
                        bottom_tensor_thread_coord);

                const vector_type_t vec{vec_value};

                // write into distributed tensor: scatter the vector's scalars along
                // the vectorized Y dim (Traits::VectorDimY)
                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    // Y-index of the j-th scalar: only VectorDimY advances
                    constexpr auto idx_ys = generate_array(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        Number<NDimY>{});

                    // linearized register-buffer offset for this Y-index
                    constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);

                    dst_tensor.GetThreadBuffer().template At<d>() =
                        vec.template AsType<DataType>()[j];
                });

                // move thread coordinate to the next access within this bundle
                // (skip after the bundle's last access)
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);

                    // P-dims stay fixed; only the Y-dims step forward
                    constexpr auto idx_diff_ps_ys =
                        container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);

                    MoveWindowAdaptorAndBottomTensorThreadCoordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });

        return dst_tensor;
    }
    // Store a thread-local static distributed tensor (dstr_tensor) into the
    // window's region of the bottom tensor.
    //
    // Mirror of Load(): gathers Traits::ScalarPerVector scalars from the thread
    // buffer into a vector, then issues one vectorized write per access, using
    // the NumCoord pre-computed coordinate bundles so only intra-bundle
    // coordinate moves are computed here. Access order over the Y dims comes
    // from Traits::SFC_Ys.
    __device__ void Store(const StaticDistributedTensor<DataType, TileDstr>& dstr_tensor) const
    {
        using Traits = LoadStoreTraits;

        using vector_type_t = typename Traits::vector_type_t;
        using vector_t = typename vector_type_t::type;
        using SFC_Ys = typename Traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            // local, mutable copies of the pre-computed coordinate bundle
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                // global access index across all coordinate bundles
                constexpr auto iAccess = Number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);

                // read from distributed tensor: gather scalars along the
                // vectorized Y dim (Traits::VectorDimY) into one vector
                vector_type_t vec;

                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    // Y-index of the j-th scalar: only VectorDimY advances
                    constexpr auto idx_ys = generate_array(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        Number<NDimY>{});

                    // linearized register-buffer offset for this Y-index
                    constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);

                    vec.template AsType<DataType>()(j) =
                        dstr_tensor.GetThreadBuffer().template At<d>();
                });

                const vector_t vec_value = vec.template AsType<vector_t>().template At<0>();

                // write into bottom tensor (one vectorized store per access)
                GetBottomTensorView().template SetVectorizedElements<vector_t>(
                    bottom_tensor_thread_coord, vec_value);

                // move thread coordinate to the next access within this bundle
                // (skip after the bundle's last access)
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);

                    // P-dims stay fixed; only the Y-dims step forward
                    constexpr auto idx_diff_ps_ys =
                        container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);

                    MoveWindowAdaptorAndBottomTensorThreadCoordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }
    // Move the window by `step` on the bottom tensor:
    // [x0', x1', ... ] ==> [offset]
    // Shifts the window origin and updates the bottom-tensor coordinate (the I1
    // element) of every pre-computed coordinate bundle so subsequent Load()/Store()
    // calls see the new position. The window-adaptor coordinates (I0) are
    // unaffected, since the per-thread [p..., y...] position within the window
    // does not change.
    __device__ void Move(const BottomTensorIndex& step)
    {
        window_origin_ += step;

        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            move_tensor_coordinate(
                bottom_tensor_view_.GetTensorDescriptor(), pre_computed_coords_(iCoord)(I1), step);
        });
    }
// this is the bottom tensor view // this is the bottom tensor view
// [x0', x1', ...] ==> [offset] // [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_; BottomTensorView bottom_tensor_view_;
...@@ -210,42 +405,51 @@ struct TileWindowWithStaticDistribution ...@@ -210,42 +405,51 @@ struct TileWindowWithStaticDistribution
// origin ([x0', x1', ...]) of window on bottom tensor // origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_; BottomTensorIndex window_origin_;
// per-thread coordinate for bottom tensor
BottomTensorCoord bottom_tensor_thread_coord_;
// Tile tensor distribution, which contains: // Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d] // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_; TileDstr tile_dstr_;
// thread window coordinate // this contains:
WindowAdaptorCoord window_adaptor_thread_coord_; // per-thread coordinate for window adaptor
// per-thread coordinate for bottom tensor
Array<Tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
}; };
// TODO: use strategy // TODO: use strategy
template <typename TensorView_, typename WindowLengths_, typename StaticTileDistribution_> template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord = 1>
__device__ constexpr auto __device__ constexpr auto
make_tile_window(const TensorView_& tensor_view, make_tile_window(const TensorView_& tensor_view,
const WindowLengths_& window_lengths, const WindowLengths_& window_lengths,
const MultiIndex<TensorView_::GetNumOfDimension()>& origin, const MultiIndex<TensorView_::GetNumOfDimension()>& origin,
const StaticTileDistribution_& tile_distribution) const StaticTileDistribution_& tile_distribution,
Number<NumCoord> = {})
{ {
return TileWindowWithStaticDistribution<remove_cvref_t<TensorView_>, return TileWindowWithStaticDistribution<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>, remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>>{ remove_cvref_t<StaticTileDistribution_>,
NumCoord>{
tensor_view, window_lengths, origin, tile_distribution}; tensor_view, window_lengths, origin, tile_distribution};
} }
template <typename TensorView_, typename WindowLengths_, typename StaticTileDistribution_> template <typename TensorView_,
__device__ void move_tile_window( typename WindowLengths_,
TileWindowWithStaticDistribution<TensorView_, WindowLengths_, StaticTileDistribution_>& window, typename StaticTileDistribution_,
const MultiIndex< index_t NumCoord>
TileWindowWithStaticDistribution<TensorView_, WindowLengths_, StaticTileDistribution_>:: __device__ void
GetNumOfDimension()>& step) move_tile_window(TileWindowWithStaticDistribution<TensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>& window,
const typename TileWindowWithStaticDistribution<TensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>::BottomTensorIndex& step)
{ {
window.window_origin_ += step; window.Move(step);
window.MoveBottomTensorThreadCoordinate(step);
} }
} // namespace tile_program } // namespace tile_program
......
...@@ -46,6 +46,9 @@ struct TileWindowWithStaticLengths ...@@ -46,6 +46,9 @@ struct TileWindowWithStaticLengths
__device__ constexpr auto GetWindowOrigin() const { return window_origin_; } __device__ constexpr auto GetWindowOrigin() const { return window_origin_; }
    // Move the window by `step` on the bottom tensor: only the window origin is
    // shifted (this window type keeps no per-thread coordinates to update).
    __device__ void Move(const BottomTensorIndex& step) { window_origin_ += step; }
// this is the bottom tensor view // this is the bottom tensor view
// [x0', x1', ...] ==> [offset] // [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_; BottomTensorView bottom_tensor_view_;
...@@ -73,10 +76,10 @@ make_tile_window(const TensorView_& tensor_view, ...@@ -73,10 +76,10 @@ make_tile_window(const TensorView_& tensor_view,
template <typename TensorView_, typename WindowLengths_> template <typename TensorView_, typename WindowLengths_>
__device__ void move_tile_window( __device__ void move_tile_window(
TileWindowWithStaticLengths<TensorView_, WindowLengths_>& window, TileWindowWithStaticLengths<TensorView_, WindowLengths_>& window,
const MultiIndex<TileWindowWithStaticLengths<TensorView_, WindowLengths_>::GetNumOfDimension()>& const typename TileWindowWithStaticLengths<TensorView_, WindowLengths_>::BottomTensorIndex&
step) step)
{ {
window.window_origin_ += step; window.Move(step);
} }
} // namespace tile_program } // namespace tile_program
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment