"docs/vscode:/vscode.git/clone" did not exist on "999044596aa5118308753a3b77b97403378f0b3d"
Commit 39821a90 authored by Chao Liu's avatar Chao Liu
Browse files

refactor, clean up

parent 01e94729
...@@ -232,13 +232,13 @@ struct DummyDynamicTransform_1 ...@@ -232,13 +232,13 @@ struct DummyDynamicTransform_1
auto in_gemmk_gemmn_coord = make_dynamic_tensor_coordinate(in_gemmk_gemmn_global_desc, idx); auto in_gemmk_gemmn_coord = make_dynamic_tensor_coordinate(in_gemmk_gemmn_global_desc, idx);
const auto in_gemmk_gemmn_coord_step = const auto in_gemmk_gemmn_coord_iterator = make_dynamic_tensor_coordinate_iterator(
make_dynamic_tensor_coordinate_step(in_gemmk_gemmn_global_desc, make_multi_index(1, 0)); in_gemmk_gemmn_global_desc, make_multi_index(1, 0));
for(index_t iter = 0; iter < niter; ++iter) for(index_t iter = 0; iter < niter; ++iter)
{ {
move_dynamic_tensor_coordinate( move_dynamic_tensor_coordinate(
in_gemmk_gemmn_global_desc, in_gemmk_gemmn_coord, in_gemmk_gemmn_coord_step); in_gemmk_gemmn_global_desc, in_gemmk_gemmn_coord, in_gemmk_gemmn_coord_iterator);
// write // write
float value = 1; float value = 1;
...@@ -352,12 +352,12 @@ struct DummyDynamicTransform_1 ...@@ -352,12 +352,12 @@ struct DummyDynamicTransform_1
auto in_coord = make_dynamic_tensor_coordinate(in_n_c_hip_wip_global_desc, idx); auto in_coord = make_dynamic_tensor_coordinate(in_n_c_hip_wip_global_desc, idx);
const auto in_coord_step = make_dynamic_tensor_coordinate_step( const auto in_coord_iterator = make_dynamic_tensor_coordinate_iterator(
in_n_c_hip_wip_global_desc, make_multi_index(1, 0, 0, 0)); in_n_c_hip_wip_global_desc, make_multi_index(1, 0, 0, 0));
for(index_t iter = 0; iter < niter; ++iter) for(index_t iter = 0; iter < niter; ++iter)
{ {
move_dynamic_tensor_coordinate(in_n_c_hip_wip_global_desc, in_coord, in_coord_step); move_dynamic_tensor_coordinate(in_n_c_hip_wip_global_desc, in_coord, in_coord_iterator);
// write // write
float value = 1; float value = 1;
...@@ -430,21 +430,24 @@ struct DummyDynamicTransform_fwd_v4r4 ...@@ -430,21 +430,24 @@ struct DummyDynamicTransform_fwd_v4r4
auto in_gemmk_gemmn_gemmkpack_coord = auto in_gemmk_gemmn_gemmkpack_coord =
make_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc, idx); make_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc, idx);
const auto in_gemmk_gemmn_gemmkpack_coord_step_0_0_1 = make_dynamic_tensor_coordinate_step( const auto in_gemmk_gemmn_gemmkpack_coord_iterator_0_0_1 =
in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 0, 1)); make_dynamic_tensor_coordinate_iterator(in_gemmk_gemmn_gemmkpack_global_desc,
make_multi_index(0, 0, 1));
const auto in_gemmk_gemmn_gemmkpack_coord_step_0_1_0 = make_dynamic_tensor_coordinate_step( const auto in_gemmk_gemmn_gemmkpack_coord_iterator_0_1_0 =
in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 1, 0)); make_dynamic_tensor_coordinate_iterator(in_gemmk_gemmn_gemmkpack_global_desc,
make_multi_index(0, 1, 0));
const auto in_gemmk_gemmn_gemmkpack_coord_step_1_0_0 = make_dynamic_tensor_coordinate_step( const auto in_gemmk_gemmn_gemmkpack_coord_iterator_1_0_0 =
in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(1, 0, 0)); make_dynamic_tensor_coordinate_iterator(in_gemmk_gemmn_gemmkpack_global_desc,
make_multi_index(1, 0, 0));
// move (0, 0, 1) // move (0, 0, 1)
for(index_t iter = 0; iter < niter; ++iter) for(index_t iter = 0; iter < niter; ++iter)
{ {
move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc, move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc,
in_gemmk_gemmn_gemmkpack_coord, in_gemmk_gemmn_gemmkpack_coord,
in_gemmk_gemmn_gemmkpack_coord_step_0_0_1); in_gemmk_gemmn_gemmkpack_coord_iterator_0_0_1);
// write // write
float value = 1; float value = 1;
...@@ -476,7 +479,7 @@ struct DummyDynamicTransform_fwd_v4r4 ...@@ -476,7 +479,7 @@ struct DummyDynamicTransform_fwd_v4r4
{ {
move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc, move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc,
in_gemmk_gemmn_gemmkpack_coord, in_gemmk_gemmn_gemmkpack_coord,
in_gemmk_gemmn_gemmkpack_coord_step_0_1_0); in_gemmk_gemmn_gemmkpack_coord_iterator_0_1_0);
// write // write
float value = 1; float value = 1;
...@@ -508,7 +511,7 @@ struct DummyDynamicTransform_fwd_v4r4 ...@@ -508,7 +511,7 @@ struct DummyDynamicTransform_fwd_v4r4
{ {
move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc, move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc,
in_gemmk_gemmn_gemmkpack_coord, in_gemmk_gemmn_gemmkpack_coord,
in_gemmk_gemmn_gemmkpack_coord_step_1_0_0); in_gemmk_gemmn_gemmkpack_coord_iterator_1_0_0);
// write // write
float value = 1; float value = 1;
......
...@@ -9,12 +9,8 @@ namespace ck { ...@@ -9,12 +9,8 @@ namespace ck {
template <index_t NDimHidden, typename VisibleDimensionIds> template <index_t NDimHidden, typename VisibleDimensionIds>
struct DynamicTensorCoordinate; struct DynamicTensorCoordinate;
#if 0 // hack template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
template <index_t NTransform, index_t NDimVisible> struct DynamicTensorCoordinateIterator;
#else
template <index_t NTransform, index_t NDimVisible, typename HackCalculateLowerIndexDiff>
#endif
struct DynamicTensorCoordinateStep;
// Transforms: Tuple<transforms...> // Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...> // LowerDimensionIdss : Tuple<Sequence<...>, ...>
...@@ -193,18 +189,14 @@ struct DynamicTensorCoordinate ...@@ -193,18 +189,14 @@ struct DynamicTensorCoordinate
HiddenIndex idx_hidden_; HiddenIndex idx_hidden_;
}; };
#if 0 // hack template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
template <index_t NTransform, index_t NDimVisible> struct DynamicTensorCoordinateIterator
#else
template <index_t NTransform, index_t NDimVisible, typename HackCalculateLowerIndexDiff>
#endif
struct DynamicTensorCoordinateStep
{ {
// TODO make these private // TODO make these private
using VisibleIndex = MultiIndex<NDimVisible>; using VisibleIndex = MultiIndex<NDimVisible>;
public: public:
__host__ __device__ explicit constexpr DynamicTensorCoordinateStep( __host__ __device__ explicit constexpr DynamicTensorCoordinateIterator(
const VisibleIndex& idx_diff_visible, const MultiIndex<NTransform>& do_transforms) const VisibleIndex& idx_diff_visible, const MultiIndex<NTransform>& do_transforms)
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms} : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
{ {
...@@ -221,10 +213,8 @@ struct DynamicTensorCoordinateStep ...@@ -221,10 +213,8 @@ struct DynamicTensorCoordinateStep
const VisibleIndex idx_diff_visible_; const VisibleIndex idx_diff_visible_;
const MultiIndex<NTransform> do_transforms_; const MultiIndex<NTransform> do_transforms_;
#if 1 // hack // HACK: control UpdateLowerIndex()
// HACK: control CalculateLowerIndexDiff for DynamicMerge using ing hack static constexpr UpdateLowerIndexHack update_lower_index_hack_;
static constexpr HackCalculateLowerIndexDiff hack_calculate_lower_index_diff_;
#endif
}; };
// TODO: How to fix this? It uses a struct instead of lambda because lambda // TODO: How to fix this? It uses a struct instead of lambda because lambda
...@@ -350,9 +340,11 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe ...@@ -350,9 +340,11 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
return DynamicTensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden}; return DynamicTensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
} }
template <typename TensorDesc, typename VisibleIndex> // UpdateLowerIndexHack: Sequence<...>
__host__ __device__ constexpr auto // HACK: control UpdateLowerIndex
make_dynamic_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_diff_visible) template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
__host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack)
{ {
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
"wrong! # of dimension inconsistent"); "wrong! # of dimension inconsistent");
...@@ -362,6 +354,8 @@ make_dynamic_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_d ...@@ -362,6 +354,8 @@ make_dynamic_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_d
constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension(); constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
static_assert(UpdateLowerIndexHack::Size() == ntransform, "wrong!");
// use index_t for boolean type // use index_t for boolean type
auto do_transforms = make_zero_multi_index<ntransform>(); auto do_transforms = make_zero_multi_index<ntransform>();
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>(); auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
...@@ -397,78 +391,23 @@ make_dynamic_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_d ...@@ -397,78 +391,23 @@ make_dynamic_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_d
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
}); });
#if 0 // hack return DynamicTensorCoordinateIterator<ntransform, ndim_visible, UpdateLowerIndexHack>{
return DynamicTensorCoordinateStep<ntransform, ndim_visible>{idx_diff_visible, do_transforms};
#else
return DynamicTensorCoordinateStep<ntransform,
ndim_visible,
typename uniform_sequence_gen<ntransform, 0>::type>{
idx_diff_visible, do_transforms}; idx_diff_visible, do_transforms};
#endif
} }
#if 0 // hack
template <typename TensorDesc, typename VisibleIndex> template <typename TensorDesc, typename VisibleIndex>
#else __host__ __device__ constexpr auto
// HACK: control CalculateLowerIndexDiff for DynamicMerge using ing hack make_dynamic_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible)
template <typename TensorDesc, typename VisibleIndex, typename HackCalculateLowerIndexDiff>
#endif
__host__ __device__ constexpr auto make_dynamic_tensor_coordinate_step_hack(
const TensorDesc&, const VisibleIndex& idx_diff_visible, HackCalculateLowerIndexDiff)
{ {
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
static_assert(HackCalculateLowerIndexDiff::Size() == ntransform, "wrong!");
// use index_t for boolean type
auto do_transforms = make_zero_multi_index<ntransform>();
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
// decide do_transform by checkout non-zero index diff components
MultiIndex<VisibleIndex::Size()> non_zero_diff_pick_visible;
static_for<0, ndim_visible, 1>{}(
[&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); });
set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible);
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
MultiIndex<dims_low.Size()> non_zero_diff_pick_low;
// if any of upper index diff components is non-zero, then
// 1) Need to do this transform
// 2) all components of lower index diff will assume to be non-zero and need to be
// computed
const bool idx_diff_up_has_non_zero = container_reduce(
non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);
do_transforms(itran) = idx_diff_up_has_non_zero;
static_for<0, dims_low.Size(), 1>{}(
[&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
});
return DynamicTensorCoordinateStep<ntransform, ndim_visible, HackCalculateLowerIndexDiff>{ return make_dynamic_tensor_coordinate_iterator(
idx_diff_visible, do_transforms}; TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
} }
template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep> template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterator>
__host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDesc& tensor_desc, __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
TensorCoord& coord, const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator)
const TensorCoordStep& coord_step)
{ {
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension(); constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
...@@ -480,8 +419,9 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDe ...@@ -480,8 +419,9 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDe
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>(); auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
// initialize visible index diff // initialize visible index diff
set_container_subset( set_container_subset(idx_diff_hidden,
idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff()); TensorDesc::GetVisibleDimensionIds(),
coord_iterator.GetVisibleIndexDiff());
// this is what needs to be updated // this is what needs to be updated
auto& idx_hidden = coord.GetHiddenIndex(); auto& idx_hidden = coord.GetHiddenIndex();
...@@ -490,13 +430,13 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDe ...@@ -490,13 +430,13 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDe
auto idx_hidden_pick_visible = auto idx_hidden_pick_visible =
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds()); get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
idx_hidden_pick_visible += coord_step.GetIndexDiff(); idx_hidden_pick_visible += coord_iterator.GetIndexDiff();
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible); set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
// update rest of hidden index // update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
if(coord_step.do_transforms_[itran]) if(coord_iterator.do_transforms_[itran])
{ {
const auto& tran = tensor_desc.GetTransforms().At(itran); const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
...@@ -509,9 +449,7 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDe ...@@ -509,9 +449,7 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDe
MultiIndex<dims_low.Size()> idx_diff_low; MultiIndex<dims_low.Size()> idx_diff_low;
// HACK: control UpdateLowerIndex for DynamicMerge using hack // HACK: control UpdateLowerIndex for DynamicMerge using hack
// TODO remove hack constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran);
constexpr index_t Hack =
decltype(coord_step.hack_calculate_lower_index_diff_)::At(itran);
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{}); tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
...@@ -579,7 +517,7 @@ using DynamicTensorCoordinate_t = decltype(make_dynamic_tensor_coordinate( ...@@ -579,7 +517,7 @@ using DynamicTensorCoordinate_t = decltype(make_dynamic_tensor_coordinate(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{})); TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
template <typename TensorDesc> template <typename TensorDesc>
using DynamicTensorCoordinateStep_t = decltype(make_dynamic_tensor_coordinate_step( using DynamicTensorCoordinateIterator_t = decltype(make_dynamic_tensor_coordinate_iterator(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{})); TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
} // namespace ck } // namespace ck
......
...@@ -87,21 +87,15 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 ...@@ -87,21 +87,15 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
return thread_cluster_id * ThreadSliceLengths{}; return thread_cluster_id * ThreadSliceLengths{};
} }
__device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src) template <typename SrcIteratorHacks>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcData* p_src,
const SrcIteratorHacks& src_iterator_hacks)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.RunRead(src_desc, p_src); threadwise_transfer_.RunRead(src_desc, p_src, src_iterator_hacks);
}
}
__device__ void RunRead_hack(const SrcDesc& src_desc, const SrcData* p_src)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead_hack(src_desc, p_src);
} }
} }
...@@ -123,12 +117,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 ...@@ -123,12 +117,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
} }
} }
__device__ void MoveSrcSliceWindow_hack(const SrcDesc& src_desc, const Index& step) // SrcMoveSliceWindowIteratorHack to control index calculation move slice window
template <typename SrcMoveSliceWindowIteratorHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrcSliceWindow_hack(src_desc, step); threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_iterator_hack);
} }
} }
......
...@@ -259,10 +259,47 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -259,10 +259,47 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_k_m_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto a_k_m_global_reset_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
// hack to control index calculation when iterating over b_k_n_global tensor
#if 0
// for padded input
constexpr auto b_k_n_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
#elif 0
// for non-padded input
constexpr auto b_k_n_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 2, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 2>{}));
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
Sequence<0, 0, 0, 0, 0, 1, 2>{};
#elif 1
// for 1x1 case
constexpr auto b_k_n_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0>{}, Sequence<0, 0, 1>{}),
make_tuple(Sequence<0, 2, 0>{}, Sequence<0, 0, 2>{}));
constexpr auto b_k_n_global_move_slice_window_iterator_hack = Sequence<0, 1, 2>{};
#endif
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global); a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead_hack(b_k_n_global_desc, p_b_global); b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double); a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double); b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
...@@ -284,14 +321,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -284,14 +321,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
{ {
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow_hack(b_k_n_global_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step); b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global); a_blockwise_copy.RunRead(
b_blockwise_copy.RunRead_hack(b_k_n_global_desc, p_b_global); a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread); blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
...@@ -302,14 +342,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -302,14 +342,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow_hack(b_k_n_global_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step); b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global); a_blockwise_copy.RunRead(
b_blockwise_copy.RunRead_hack(b_k_n_global_desc, p_b_global); a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread); blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
...@@ -326,13 +369,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -326,13 +369,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow_hack(b_k_n_global_desc, b_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global); a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead_hack(b_k_n_global_desc, p_b_global); b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
...@@ -383,6 +428,18 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -383,6 +428,18 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
const index_t n_thread_data_on_global = const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col; n_block_data_on_global + c_thread_mtx_on_block.col;
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
AccFloat, AccFloat,
Float, Float,
...@@ -402,7 +459,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -402,7 +459,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
m_thread_data_on_global % M1, m_thread_data_on_global % M1,
n_thread_data_on_global / N1, n_thread_data_on_global / N1,
n_thread_data_on_global % N1)) n_thread_data_on_global % N1))
.Run_hack(p_c_thread, c_m0_m1_n0_n1_global_desc, p_c_global); .Run(p_c_thread,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
} }
} }
...@@ -435,5 +495,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -435,5 +495,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
}; };
} // namespace ck } // namespace ck
#endif #endif
...@@ -55,11 +55,12 @@ void device_dummy_dynamic_transform(InDesc, ...@@ -55,11 +55,12 @@ void device_dummy_dynamic_transform(InDesc,
auto in_gemmk_gemmn_gemmkpack_coord = make_dynamic_tensor_coordinate( auto in_gemmk_gemmn_gemmkpack_coord = make_dynamic_tensor_coordinate(
in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 0, 0)); in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 0, 0));
const auto in_gemmk_gemmn_gemmkpack_coord_step_0_0_1 = make_dynamic_tensor_coordinate_step( const auto in_gemmk_gemmn_gemmkpack_coord_iterator_0_0_1 =
in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 0, 1)); make_dynamic_tensor_coordinate_iterator(in_gemmk_gemmn_gemmkpack_global_desc,
make_multi_index(0, 0, 1));
print_array_v2("do_tansforms 0 0 1: ", print_array_v2("do_tansforms 0 0 1: ",
in_gemmk_gemmn_gemmkpack_coord_step_0_0_1.do_transforms_); in_gemmk_gemmn_gemmkpack_coord_iterator_0_0_1.do_transforms_);
for(index_t iter = 0; iter < 10; ++iter) for(index_t iter = 0; iter < 10; ++iter)
{ {
...@@ -71,7 +72,7 @@ void device_dummy_dynamic_transform(InDesc, ...@@ -71,7 +72,7 @@ void device_dummy_dynamic_transform(InDesc,
move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc, move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc,
in_gemmk_gemmn_gemmkpack_coord, in_gemmk_gemmn_gemmkpack_coord,
in_gemmk_gemmn_gemmkpack_coord_step_0_0_1); in_gemmk_gemmn_gemmkpack_coord_iterator_0_0_1);
} }
} }
...@@ -79,11 +80,12 @@ void device_dummy_dynamic_transform(InDesc, ...@@ -79,11 +80,12 @@ void device_dummy_dynamic_transform(InDesc,
auto in_gemmk_gemmn_gemmkpack_coord = make_dynamic_tensor_coordinate( auto in_gemmk_gemmn_gemmkpack_coord = make_dynamic_tensor_coordinate(
in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 0, 0)); in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 0, 0));
const auto in_gemmk_gemmn_gemmkpack_coord_step_0_1_0 = make_dynamic_tensor_coordinate_step( const auto in_gemmk_gemmn_gemmkpack_coord_iterator_0_1_0 =
in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 1, 0)); make_dynamic_tensor_coordinate_iterator(in_gemmk_gemmn_gemmkpack_global_desc,
make_multi_index(0, 1, 0));
print_array_v2("do_tansforms 0 1 0: ", print_array_v2("do_tansforms 0 1 0: ",
in_gemmk_gemmn_gemmkpack_coord_step_0_1_0.do_transforms_); in_gemmk_gemmn_gemmkpack_coord_iterator_0_1_0.do_transforms_);
for(index_t iter = 0; iter < 10; ++iter) for(index_t iter = 0; iter < 10; ++iter)
{ {
...@@ -95,7 +97,7 @@ void device_dummy_dynamic_transform(InDesc, ...@@ -95,7 +97,7 @@ void device_dummy_dynamic_transform(InDesc,
move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc, move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc,
in_gemmk_gemmn_gemmkpack_coord, in_gemmk_gemmn_gemmkpack_coord,
in_gemmk_gemmn_gemmkpack_coord_step_0_1_0); in_gemmk_gemmn_gemmkpack_coord_iterator_0_1_0);
} }
} }
...@@ -103,11 +105,12 @@ void device_dummy_dynamic_transform(InDesc, ...@@ -103,11 +105,12 @@ void device_dummy_dynamic_transform(InDesc,
auto in_gemmk_gemmn_gemmkpack_coord = make_dynamic_tensor_coordinate( auto in_gemmk_gemmn_gemmkpack_coord = make_dynamic_tensor_coordinate(
in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 0, 0)); in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 0, 0));
const auto in_gemmk_gemmn_gemmkpack_coord_step_1_0_0 = make_dynamic_tensor_coordinate_step( const auto in_gemmk_gemmn_gemmkpack_coord_iterator_1_0_0 =
in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(1, 0, 0)); make_dynamic_tensor_coordinate_iterator(in_gemmk_gemmn_gemmkpack_global_desc,
make_multi_index(1, 0, 0));
print_array_v2("do_tansforms 1 0 0: ", print_array_v2("do_tansforms 1 0 0: ",
in_gemmk_gemmn_gemmkpack_coord_step_1_0_0.do_transforms_); in_gemmk_gemmn_gemmkpack_coord_iterator_1_0_0.do_transforms_);
for(index_t iter = 0; iter < 10; ++iter) for(index_t iter = 0; iter < 10; ++iter)
{ {
...@@ -119,7 +122,7 @@ void device_dummy_dynamic_transform(InDesc, ...@@ -119,7 +122,7 @@ void device_dummy_dynamic_transform(InDesc,
move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc, move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc,
in_gemmk_gemmn_gemmkpack_coord, in_gemmk_gemmn_gemmkpack_coord,
in_gemmk_gemmn_gemmkpack_coord_step_1_0_0); in_gemmk_gemmn_gemmkpack_coord_iterator_1_0_0);
} }
} }
......
...@@ -233,7 +233,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -233,7 +233,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto conv_driver = constexpr auto conv_driver =
#if 1 #if 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
#elif 0 #elif 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
......
...@@ -52,7 +52,7 @@ int main(int argc, char* argv[]) ...@@ -52,7 +52,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
// 3x3, 71x71 // 3x3, 71x71
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 192; constexpr index_t C = 192;
...@@ -67,7 +67,7 @@ int main(int argc, char* argv[]) ...@@ -67,7 +67,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 1
// 1x1, 8x8 // 1x1, 8x8
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1536; constexpr index_t C = 1536;
...@@ -592,7 +592,7 @@ int main(int argc, char* argv[]) ...@@ -592,7 +592,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment