"...resnet50_tensorflow.git" did not exist on "cf60559f2f0fa165b77eacd5fb19c0d2cb8bfee5"
Unverified Commit 63bc96e3 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Pre-compute coordinates to speed up store_tile() for TileWindowWithStaticDistribution<> (#12)



* Extract store_tile() logics as method

* Extract load_tile() logics as method

* Rename type alias

* Extract common logics as traits

* Remove unnecessary access specifier

* Add ComputeMode for TileWindowWithStaticDistribution

* Put field check into Traits

* More definition of Traits types

* Use more clear static_assert() message

* Enable pre-compute coordinates in store_tile()

* Re-format static assert

* Undo changes to the wrong method

* Enable pre-compute coords for store_tile()

* Remove static_vector usage

* Add method to move non-member coordinates

* Force using pre-computed coordinates in Store()

* Fix wrong access for SFC_Ys

* Change comment

* Allow users to hint # access per coord

* Add comment for noting remove data members later

* Unify FIXME comments

* Replace FIXME comments by TODO

* Let user specify HintNumCoords

* clean

* clean

* clean

* clean

* refactor load/store for window

* clean

* clean

* bug fix for window; clean

---------
Co-authored-by: Chao Liu <chao.liu2@amd.com>
parent 7337ec25
...@@ -238,16 +238,16 @@ struct GemmSoftmaxGemmImpl ...@@ -238,16 +238,16 @@ struct GemmSoftmaxGemmImpl
using SBlockTileType = decltype(tile_elementwise_in( using SBlockTileType = decltype(tile_elementwise_in(
type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{})); type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{}));
using PBlockTileType = decltype( using PBlockTileType = decltype(tile_elementwise_in(type_convert<PDataType, SaccDataType>,
tile_elementwise_in(type_convert<PDataType, SaccDataType>, SaccBlockTileType{})); SaccBlockTileType{}));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>( using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype( using OaccBlockTileType = decltype(gemm1(
gemm1(get_slice_tile( get_slice_tile(
PBlockTileType{}, Sequence<0, 0>{}, Sequence<kM0PerBlock, kK1PerBlock>{}), PBlockTileType{}, Sequence<0, 0>{}, Sequence<kM0PerBlock, kK1PerBlock>{}),
v_dram_window)); v_dram_window));
// init Oacc, M, L // init Oacc, M, L
auto o_acc = OaccBlockTileType{}; auto o_acc = OaccBlockTileType{};
...@@ -322,7 +322,7 @@ struct GemmSoftmaxGemmImpl ...@@ -322,7 +322,7 @@ struct GemmSoftmaxGemmImpl
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this use different equation from FA v2 paper, // FIXME: this use different equation from FA v2 paper,
// but produce correc result. // but produce correct result.
// Is the equation wrong? // Is the equation wrong?
o_acc(i_j_idx) *= tmp; o_acc(i_j_idx) *= tmp;
}); });
...@@ -336,6 +336,7 @@ struct GemmSoftmaxGemmImpl ...@@ -336,6 +336,7 @@ struct GemmSoftmaxGemmImpl
const auto p = const auto p =
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute); tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
// Oacc{j}
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
if constexpr(k1_loops > 1) if constexpr(k1_loops > 1)
...@@ -369,7 +370,7 @@ struct GemmSoftmaxGemmImpl ...@@ -369,7 +370,7 @@ struct GemmSoftmaxGemmImpl
} while(iN0 < N0); } while(iN0 < N0);
// O // Oacc
constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans();
sweep_tile_span(o_spans[I0], [&](auto idx0) { sweep_tile_span(o_spans[I0], [&](auto idx0) {
......
...@@ -48,18 +48,18 @@ struct SpaceFillingCurve ...@@ -48,18 +48,18 @@ struct SpaceFillingCurve
ScalarPerVector; ScalarPerVector;
} }
template <index_t AccessIdx1dBegin, index_t AccessIdx1dEnd> template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
static __device__ __host__ constexpr auto GetStepBetween(Number<AccessIdx1dBegin>, static __device__ __host__ constexpr auto GetStepBetween(Number<AccessIdx1dHead>,
Number<AccessIdx1dEnd>) Number<AccessIdx1dTail>)
{ {
static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative"); static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < GetNumOfAccess(),
static_assert(AccessIdx1dBegin < GetNumOfAccess(), "1D index should be larger than 0"); "1D index out of range");
static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative"); static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < GetNumOfAccess(),
static_assert(AccessIdx1dEnd < GetNumOfAccess(), "1D index should be larger than 0"); "1D index out of range");
constexpr auto idx_begin = GetIndex(Number<AccessIdx1dBegin>{}); constexpr auto idx_head = GetIndex(Number<AccessIdx1dHead>{});
constexpr auto idx_end = GetIndex(Number<AccessIdx1dEnd>{}); constexpr auto idx_tail = GetIndex(Number<AccessIdx1dTail>{});
return idx_end - idx_begin; return idx_tail - idx_head;
} }
template <index_t AccessIdx1d> template <index_t AccessIdx1d>
......
...@@ -16,306 +16,16 @@ ...@@ -16,306 +16,16 @@
namespace ck { namespace ck {
namespace tile_program { namespace tile_program {
// detail used by tile-programming APIs(), not supposed to be used directly
namespace detail {
// TODO: deprecate
// "Y dimension": Y dimensions inside TileWindowWithStaticDistribution
// input:
// y_slice_origin: starting slice origin of Y dimension
// y_slice_lengths: slice lengths of Y dimensionr
// output:
// A StaticBuffer holding slice of thread data, and data layout is hardcoded to be in the order of
// [Y0, Y1, Y2, ...]
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename YIndex, index_t NumCoord>
index_t... YSliceLengths> __device__ auto load_tile(TileWindowWithStaticDistribution<BottomTensorView_,
__device__ auto load_sliced_thread_data_from_tile_window( WindowLengths_,
TileWindowWithStaticDistribution<BottomTensorView_, WindowLengths_, TileDistribution_>& TileDistribution_,
tile_window, NumCoord>& tile_window)
const YIndex& ys_slice_origin,
Sequence<YSliceLengths...>)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using BottomTensorView = remove_cvref_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
constexpr auto tile_dstr = TileDstr{};
constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
static_assert(NDimY == YIndex::Size() && NDimY == sizeof...(YSliceLengths),
"wrong! inconsistent # of dimension");
static_assert(TileWindow::HasStaticTileDistribution(),
"wrong! assume static tile distribution");
constexpr auto y_slice_lengths = Sequence<YSliceLengths...>{};
constexpr index_t thread_element_size =
container_reduce(y_slice_lengths, math::multiplies{}, 1);
StaticBuffer<AddressSpaceEnum::Vgpr, DataType, thread_element_size, true> thread_buf;
constexpr auto tmp = [&y_slice_lengths]() {
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY = 0;
index_t ScalarPerVector = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
{
ScalarPerVector = math::gcd(ys_vector_lengths[i], y_slice_lengths[i]);
VectorDimY = i;
}
}
return make_tuple(VectorDimY, ScalarPerVector);
}();
constexpr index_t VectorDimY = tmp.template At<0>();
constexpr index_t ScalarPerVector = tmp.template At<1>();
// FIXME
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
using SFC_Ys =
SpaceFillingCurve<decltype(y_slice_lengths), DimAccessOrder, decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Ys::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// move to slice origin
const auto ps_ys_slice_origin = container_concat(Array<index_t, NDimP>{0}, ys_slice_origin);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(ps_ys_slice_origin);
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// read from bottom tensor
const vector_t vec_value =
tile_window.GetBottomTensorView().template GetVectorizedElements<vector_t>(
tile_window.GetBottomTensorThreadCoordinate());
const vector_type_t vec{vec_value};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// write into distributed tensor
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
thread_buf.template At<d>() = vec.template AsType<DataType>()[j];
});
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
// move thread coordinate back to origin
{
constexpr auto idx_diff_ys = SFC_Ys::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
// move back to origin
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(MultiIndex<NDimP + NDimY>{0} -
ps_ys_slice_origin);
return thread_buf;
}
} // namespace detail
template <typename BottomTensorView_, typename WindowLengths_, typename TileDistribution_>
__device__ auto
load_tile(TileWindowWithStaticDistribution<BottomTensorView_, WindowLengths_, TileDistribution_>&
tile_window)
{ {
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>; return tile_window.Load();
using BottomTensorView = remove_cvref_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
static_assert(is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileWindow::HasStaticTileDistribution(), "wrong!");
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.GetYs2DDescriptor().GetLengths());
constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
static_assert(TileWindow::HasStaticTileDistribution(),
"wrong! assume static tile distribution");
constexpr auto tmp = [&thread_tensor_lengths_ys]() {
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY = 0;
index_t ScalarPerVector = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
{
ScalarPerVector = ys_vector_lengths[i];
VectorDimY = i;
}
}
return make_tuple(VectorDimY, ScalarPerVector);
}();
constexpr index_t VectorDimY = tmp.template At<0>();
constexpr index_t ScalarPerVector = tmp.template At<1>();
// FIXME
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
using SFC_Ys = SpaceFillingCurve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Ys::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
#if 1 // debug
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// read from bottom tensor
const vector_t vec_value =
tile_window.GetBottomTensorView().template GetVectorizedElements<vector_t>(
tile_window.GetBottomTensorThreadCoordinate());
const vector_type_t vec{vec_value};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// write into distributed tensor
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
dst_tensor.GetThreadBuffer().template At<d>() = vec.template AsType<DataType>()[j];
});
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
// move thread coordinate back to origin
{
constexpr auto idx_diff_ys = SFC_Ys::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
#else
auto tile_window_tmp = tile_window;
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// read from bottom tensor
const vector_t vec_value =
tile_window.GetBottomTensorView().template GetVectorizedElements<vector_t>(
tile_window_tmp.GetBottomTensorThreadCoordinate());
const vector_type_t vec{vec_value};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// write into distributed tensor
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
dst_tensor.GetThreadBuffer().template At<d>() = vec.template AsType<DataType>()[j];
});
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window_tmp.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
#endif
return dst_tensor;
} }
} // namespace tile_program } // namespace tile_program
......
...@@ -17,115 +17,15 @@ namespace tile_program { ...@@ -17,115 +17,15 @@ namespace tile_program {
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord,
typename DataType_> typename DataType_>
__device__ void __device__ void store_tile(TileWindowWithStaticDistribution<BottomTensorView_,
store_tile(TileWindowWithStaticDistribution<BottomTensorView_, WindowLengths_, TileDistribution_>& WindowLengths_,
tile_window, TileDistribution_,
const StaticDistributedTensor<DataType_, TileDistribution_>& dstr_tensor) NumCoord>& tile_window,
const StaticDistributedTensor<DataType_, TileDistribution_>& dstr_tensor)
{ {
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>; tile_window.Store(dstr_tensor);
using BottomTensorView = remove_cvref_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
static_assert(TileWindow::HasStaticTileDistribution(), "wrong!");
constexpr auto tile_dstr = TileDstr{};
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.GetYs2DDescriptor().GetLengths());
constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
constexpr auto tmp = []() {
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY = 0;
index_t ScalarPerVector = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
{
ScalarPerVector = ys_vector_lengths[i];
VectorDimY = i;
}
}
return make_tuple(VectorDimY, ScalarPerVector);
}();
constexpr index_t VectorDimY = tmp.template At<0>();
constexpr index_t ScalarPerVector = tmp.template At<1>();
// FIXME:
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
using SFC_Ys = SpaceFillingCurve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Ys::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// read from distributed tensor
vector_type_t vec;
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
vec.template AsType<DataType>()(j) = dstr_tensor.GetThreadBuffer().template At<d>();
});
const vector_t vec_value = vec.template AsType<vector_t>().template At<0>();
// write into bottom tensor
tile_window.GetBottomTensorView().template SetVectorizedElements<vector_t>(
tile_window.GetBottomTensorThreadCoordinate(), vec_value);
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
// move thread coordinate back to origin
{
constexpr auto idx_diff_ys = SFC_Ys::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
} }
} // namespace tile_program } // namespace tile_program
......
...@@ -23,11 +23,8 @@ __device__ void ...@@ -23,11 +23,8 @@ __device__ void
store_tile(TileWindowWithStaticLengths<BottomTensorView_, WindowLengths_>& tile_window_tmp, store_tile(TileWindowWithStaticLengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const StaticDistributedTensor<DataType_, TileDistribution_>& dstr_tensor) const StaticDistributedTensor<DataType_, TileDistribution_>& dstr_tensor)
{ {
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>; using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using BottomTensorView = remove_cvref_t<BottomTensorView_>; using TileDstr = remove_cvref_t<TileDistribution_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<TileDistribution_>;
using TileWindow = TileWindowWithStaticDistribution<BottomTensorView, WindowLengths, TileDstr>;
static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!"); static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
...@@ -38,98 +35,7 @@ store_tile(TileWindowWithStaticLengths<BottomTensorView_, WindowLengths_>& tile_ ...@@ -38,98 +35,7 @@ store_tile(TileWindowWithStaticLengths<BottomTensorView_, WindowLengths_>& tile_
tile_window_tmp.GetWindowOrigin(), tile_window_tmp.GetWindowOrigin(),
tile_dstr); tile_dstr);
constexpr auto thread_tensor_lengths_ys = tile_window.Store(dstr_tensor);
to_sequence(tile_dstr.GetYs2DDescriptor().GetLengths());
constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
constexpr auto tmp = []() {
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY = 0;
index_t ScalarPerVector = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
{
ScalarPerVector = ys_vector_lengths[i];
VectorDimY = i;
}
}
return make_tuple(VectorDimY, ScalarPerVector);
}();
constexpr index_t VectorDimY = tmp.template At<0>();
constexpr index_t ScalarPerVector = tmp.template At<1>();
// FIXME:
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
using SFC_Ys = SpaceFillingCurve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Ys::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// read from distributed tensor
vector_type_t vec;
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
vec.template AsType<DataType>()(j) = dstr_tensor.GetThreadBuffer().template At<d>();
});
const vector_t vec_value = vec.template AsType<vector_t>().template At<0>();
// write into bottom tensor
tile_window.GetBottomTensorView().template SetVectorizedElements<vector_t>(
tile_window.GetBottomTensorThreadCoordinate(), vec_value);
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
// move thread coordinate back to origin
{
constexpr auto idx_diff_ys = SFC_Ys::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
} }
} // namespace tile_program } // namespace tile_program
......
...@@ -46,6 +46,9 @@ struct TileWindowWithStaticLengths ...@@ -46,6 +46,9 @@ struct TileWindowWithStaticLengths
__device__ constexpr auto GetWindowOrigin() const { return window_origin_; } __device__ constexpr auto GetWindowOrigin() const { return window_origin_; }
// move window-origin
__device__ void Move(const BottomTensorIndex& step) { window_origin_ += step; }
// this is the bottom tensor view // this is the bottom tensor view
// [x0', x1', ...] ==> [offset] // [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_; BottomTensorView bottom_tensor_view_;
...@@ -73,10 +76,10 @@ make_tile_window(const TensorView_& tensor_view, ...@@ -73,10 +76,10 @@ make_tile_window(const TensorView_& tensor_view,
template <typename TensorView_, typename WindowLengths_> template <typename TensorView_, typename WindowLengths_>
__device__ void move_tile_window( __device__ void move_tile_window(
TileWindowWithStaticLengths<TensorView_, WindowLengths_>& window, TileWindowWithStaticLengths<TensorView_, WindowLengths_>& window,
const MultiIndex<TileWindowWithStaticLengths<TensorView_, WindowLengths_>::GetNumOfDimension()>& const typename TileWindowWithStaticLengths<TensorView_, WindowLengths_>::BottomTensorIndex&
step) step)
{ {
window.window_origin_ += step; window.Move(step);
} }
} // namespace tile_program } // namespace tile_program
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment