Unverified Commit ccc4a1d3 authored by Chao Liu, committed by GitHub

Merge pull request #8 from ROCmSoftwarePlatform/miopen_downstream_init_integration

parents 3b866461 16effa76
@@ -2,8 +2,8 @@
#define CK_TENSOR_ADAPTOR_HPP #define CK_TENSOR_ADAPTOR_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms}; remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
} }
template <typename X, template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
typename... Xs,
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs) __host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{ {
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
...
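The variadic chain_tensor_adaptors overload above reduces N adaptors by peeling off the first argument and recursing on the rest until the binary overload takes over. A standalone sketch of that recursion pattern follows; combine() is a toy stand-in for the binary chain_tensor_adaptors(x, y), and everything in it is illustrative rather than library code.

// Toy model of the variadic recursion used by chain_tensor_adaptors above.
#include <cstdio>
#include <type_traits>

constexpr int combine(int x, int y) { return 10 * x + y; } // binary base case

template <typename X,
          typename... Xs,
          typename std::enable_if<(sizeof...(Xs) >= 2), bool>::type = false>
constexpr int combine(X x, Xs... xs)
{
    return combine(x, combine(xs...)); // peel off one argument, recurse on the rest
}

int main()
{
    // Right-nested reduction: combine(1, combine(2, combine(3, 4))) == 64.
    std::printf("%d\n", combine(1, 2, 3, 4)); // prints 64
    return 0;
}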
#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP #ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP #define CK_TENSOR_DESCRIPTOR_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform.hpp" #include "multi_index_transform.hpp"
namespace ck { namespace ck {
template <index_t NDimHidden, typename VisibleDimensionIds> template <index_t NDimHidden, typename VisibleDimensionIds>
struct DynamicTensorCoordinate; struct TensorCoordinate;
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack> template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct DynamicTensorCoordinateIterator; struct TensorCoordinateStep;
// Transforms: Tuple<transforms...> // Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...> // LowerDimensionIdss : Tuple<Sequence<...>, ...>
@@ -21,7 +21,7 @@ template <typename Transforms,
typename UpperDimensionIdss, typename UpperDimensionIdss,
typename VisibleDimensionIds, typename VisibleDimensionIds,
typename ElementSpaceSize> typename ElementSpaceSize>
struct DynamicTensorDescriptor struct TensorDescriptor
{ {
// TODO make these private // TODO make these private
__host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
@@ -105,16 +105,16 @@ struct DynamicTensorDescriptor
using VisibleIndex = MultiIndex<ndim_visible_>; using VisibleIndex = MultiIndex<ndim_visible_>;
using HiddenIndex = MultiIndex<ndim_hidden_>; using HiddenIndex = MultiIndex<ndim_hidden_>;
using Coordinate = DynamicTensorCoordinate<ndim_hidden_, VisibleDimensionIds>; using Coordinate = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
// may be index_t or Number<> // may be index_t or Number<>
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>; using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
public: public:
__host__ __device__ constexpr DynamicTensorDescriptor() = default; __host__ __device__ constexpr TensorDescriptor() = default;
__host__ __device__ constexpr DynamicTensorDescriptor(const Transforms& transforms, __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
ElementSpaceSize element_space_size) ElementSpaceSize element_space_size)
: transforms_{transforms}, : transforms_{transforms},
element_size_{InitializeElementSize(transforms)}, element_size_{InitializeElementSize(transforms)},
element_space_size_{element_space_size} element_space_size_{element_space_size}
@@ -159,7 +159,7 @@ struct DynamicTensorDescriptor
{ {
static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension"); static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");
return make_dynamic_tensor_coordinate(*this, idx).GetOffset(); return make_tensor_coordinate(*this, idx).GetOffset();
} }
// TODO make these private // TODO make these private
@@ -196,7 +196,7 @@ struct DynamicTensorDescriptor
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicTensorDescriptor, "); printf("TensorDescriptor, ");
static_for<0, ntransform_, 1>{}([&](auto i) { static_for<0, ntransform_, 1>{}([&](auto i) {
printf("transforms: "); printf("transforms: ");
transforms_[i].Print(); transforms_[i].Print();
@@ -217,7 +217,7 @@ struct DynamicTensorDescriptor
}; };
template <index_t NDimHidden, typename VisibleDimensionIds> template <index_t NDimHidden, typename VisibleDimensionIds>
struct DynamicTensorCoordinate struct TensorCoordinate
{ {
// TODO make these private // TODO make these private
static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size(); static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();
@@ -226,9 +226,9 @@ struct DynamicTensorCoordinate
using VisibleIndex = MultiIndex<ndim_visible_>; using VisibleIndex = MultiIndex<ndim_visible_>;
public: public:
__host__ __device__ constexpr DynamicTensorCoordinate() = default; __host__ __device__ constexpr TensorCoordinate() = default;
__host__ __device__ constexpr DynamicTensorCoordinate(const HiddenIndex& idx_hidden) __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
: idx_hidden_{idx_hidden} : idx_hidden_{idx_hidden}
{ {
} }
@@ -252,16 +252,16 @@ struct DynamicTensorCoordinate
}; };
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack> template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct DynamicTensorCoordinateIterator struct TensorCoordinateStep
{ {
// TODO make these private // TODO make these private
using VisibleIndex = MultiIndex<NDimVisible>; using VisibleIndex = MultiIndex<NDimVisible>;
public: public:
__host__ __device__ constexpr DynamicTensorCoordinateIterator() = default; __host__ __device__ constexpr TensorCoordinateStep() = default;
__host__ __device__ constexpr DynamicTensorCoordinateIterator( __host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
const VisibleIndex& idx_diff_visible, const MultiIndex<NTransform>& do_transforms) const MultiIndex<NTransform>& do_transforms)
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms} : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
{ {
} }
@@ -283,7 +283,7 @@ struct DynamicTensorCoordinateIterator
// TODO: How to fix this? It uses an struct instead of lambda because lambda // TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor, and to put it outside the scope where it is used // doesn't have constructor, and to put it outside the scope where it is used
// (transform_dynamic_tensor_descriptor) because template cannot be defined inside a function // (transform_tensor_descriptor) because template cannot be defined inside a function
// template // template
template <typename NewTransforms> template <typename NewTransforms>
struct lambda_get_up_dim_num struct lambda_get_up_dim_num
@@ -301,10 +301,10 @@ template <typename OldTensorDescriptor,
typename NewLowerDimensionOldVisibleIdss, typename NewLowerDimensionOldVisibleIdss,
typename NewUpperDimensionNewVisibleIdss> typename NewUpperDimensionNewVisibleIdss>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const NewTransforms& new_transforms, const NewTransforms& new_transforms,
NewLowerDimensionOldVisibleIdss, NewLowerDimensionOldVisibleIdss,
NewUpperDimensionNewVisibleIdss) NewUpperDimensionNewVisibleIdss)
{ {
// sanity check // sanity check
{ {
@@ -376,17 +376,17 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const auto element_space_size = old_tensor_desc.GetElementSpaceSize(); const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
return DynamicTensorDescriptor<remove_cv_t<decltype(all_transforms)>, return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>, remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>, remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(new_visible_dim_hidden_ids)>, remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{all_transforms, remove_cv_t<decltype(element_space_size)>>{all_transforms,
element_space_size}; element_space_size};
} }
template <typename TensorDesc, typename VisibleIndex> template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDesc& tensor_desc, __host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
const VisibleIndex& idx_visible) const VisibleIndex& idx_visible)
{ {
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
"wrong! # of dimension inconsistent"); "wrong! # of dimension inconsistent");
@@ -416,14 +416,15 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
set_container_subset(idx_hidden, dims_low, idx_low); set_container_subset(idx_hidden, dims_low, idx_low);
}); });
return DynamicTensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden}; return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
} }
// UpdateLowerIndexHack: Sequence<...> // UpdateLowerIndexHack: Sequence<...>
// HACK: control UpdateLowerIndex // HACK: control UpdateLowerIndex
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack> template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
__host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator( __host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack) const VisibleIndex& idx_diff_visible,
UpdateLowerIndexHack)
{ {
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
"wrong! # of dimension inconsistent"); "wrong! # of dimension inconsistent");
@@ -470,23 +471,24 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
}); });
return DynamicTensorCoordinateIterator<ntransform, ndim_visible, UpdateLowerIndexHack>{ return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
idx_diff_visible, do_transforms}; do_transforms};
} }
template <typename TensorDesc, typename VisibleIndex> template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
make_dynamic_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible) const VisibleIndex& idx_diff_visible)
{ {
constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
return make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{}); TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
} }
template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterator> template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
__host__ __device__ constexpr void move_dynamic_tensor_coordinate( __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator) TensorCoord& coord,
const TensorCoordStep& coord_step)
{ {
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
@@ -495,9 +497,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>(); auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
// initialize visible index diff // initialize visible index diff
set_container_subset(idx_diff_hidden, set_container_subset(
TensorDesc::GetVisibleDimensionIds(), idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
coord_iterator.GetVisibleIndexDiff());
// this is what needs to be updated // this is what needs to be updated
auto& idx_hidden = coord.GetHiddenIndex(); auto& idx_hidden = coord.GetHiddenIndex();
@@ -506,13 +507,13 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
auto idx_hidden_pick_visible = auto idx_hidden_pick_visible =
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds()); get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
idx_hidden_pick_visible += coord_iterator.GetIndexDiff(); idx_hidden_pick_visible += coord_step.GetIndexDiff();
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible); set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
// update rest of hidden index // update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
if(coord_iterator.do_transforms_[itran]) if(coord_step.do_transforms_[itran])
{ {
const auto& tran = tensor_desc.GetTransforms().At(itran); const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
@@ -524,8 +525,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
MultiIndex<dims_low.Size()> idx_diff_low; MultiIndex<dims_low.Size()> idx_diff_low;
// HACK: control UpdateLowerIndex for DynamicMerge using hack // HACK: control UpdateLowerIndex for Merge using hack
constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran); constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{}); tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
@@ -585,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
} }
template <typename TensorDesc> template <typename TensorDesc>
using DynamicTensorCoordinate_t = decltype(make_dynamic_tensor_coordinate( using TensorCoordinate_t = decltype(make_tensor_coordinate(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{})); TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
template <typename TensorDesc> template <typename TensorDesc>
using DynamicTensorCoordinateIterator_t = decltype(make_dynamic_tensor_coordinate_iterator( using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{})); TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
} // namespace ck } // namespace ck
...
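As a usage sketch of the renamed coordinate API above: the calls below assume the composable_kernel headers from this change are on the include path, that make_multi_index() is available from the common headers, and they borrow the packed-descriptor factory from tensor_descriptor_helper.hpp (the next file in this diff). Names and numbers are illustrative only.

// Minimal sketch: build a coordinate, a step, and move the coordinate.
#include "tensor_descriptor_helper.hpp" // also pulls in tensor_descriptor.hpp

using namespace ck;

void walk_rows_sketch()
{
    // 4 x 8 packed (row-major) descriptor with compile-time lengths.
    const auto desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));

    // Coordinate at (1, 2) and a reusable step of (+1, 0), i.e. "one row down".
    auto coord      = make_tensor_coordinate(desc, make_multi_index(1, 2));
    const auto step = make_tensor_coordinate_step(desc, make_multi_index(1, 0));

    const index_t offset_before = coord.GetOffset(); // 1 * 8 + 2 = 10 in this packed layout

    move_tensor_coordinate(desc, coord, step); // updates the hidden index in place

    const index_t offset_after = coord.GetOffset(); // 2 * 8 + 2 = 18

    (void)offset_before;
    (void)offset_after;
}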
#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP #ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP #define CK_TENSOR_DESCRIPTOR_HELPER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
namespace ck { namespace ck {
@@ -37,10 +37,9 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
template <typename... Lengths, template <typename... Lengths,
typename... Strides, typename... Strides,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false> typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths, const Tuple<Strides...>& strides)
const Tuple<Strides...>& strides)
{ {
constexpr index_t N = sizeof...(Lengths); constexpr index_t N = sizeof...(Lengths);
@@ -75,12 +74,12 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{}); calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
#endif #endif
return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>, return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>, remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>, remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>, remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms, remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size}; element_space_size};
} }
// Lengths... can be: // Lengths... can be:
@@ -88,7 +87,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
// 2) Number<>, which is known at compile-time // 2) Number<>, which is known at compile-time
template <typename... Lengths> template <typename... Lengths>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths) make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
{ {
constexpr index_t N = sizeof...(Lengths); constexpr index_t N = sizeof...(Lengths);
@@ -103,17 +102,17 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{}); const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>, return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>, remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>, remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>, remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms, remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size}; element_space_size};
} }
template <typename... Lengths, typename Align> template <typename... Lengths, typename Align>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align) make_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align)
{ {
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
@@ -143,7 +142,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths
}, },
Number<N>{}); Number<N>{});
return make_dynamic_naive_tensor_descriptor_v2(lengths, strides); return make_naive_tensor_descriptor_v2(lengths, strides);
} }
} // namespace ck } // namespace ck
...
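A short sketch of the three renamed factory helpers from this file, following the comment above that lengths (and strides) may mix run-time index_t values with compile-time Number<> values. It assumes these headers are on the include path and is illustrative rather than normative.

// Minimal sketch of the renamed naive-descriptor factories.
#include "tensor_descriptor_helper.hpp"

using namespace ck;

void descriptor_factories_sketch(index_t M, index_t N)
{
    // Explicit strides: a run-time length M mixed with compile-time Number<> values.
    const auto desc_strided = make_naive_tensor_descriptor_v2(
        make_tuple(M, Number<8>{}),            // lengths
        make_tuple(Number<8>{}, Number<1>{})); // strides

    // Packed: strides are derived so the innermost dimension is contiguous.
    const auto desc_packed =
        make_naive_tensor_descriptor_packed(make_tuple(M, N, Number<4>{}));

    // Aligned: packed lengths, but with strides padded to the requested
    // alignment before forwarding to make_naive_tensor_descriptor_v2.
    const auto desc_aligned =
        make_naive_tensor_descriptor_aligned_v2(make_tuple(M, N), Number<64>{});

    (void)desc_strided.GetElementSpaceSize();
    (void)desc_packed.GetElementSpaceSize();
    (void)desc_aligned.GetElementSpaceSize();
}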
@@ -3,7 +3,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_adaptor.hpp" #include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_contraction_dlops.hpp" #include "threadwise_contraction_dlops.hpp"
namespace ck { namespace ck {
@@ -22,24 +22,24 @@ namespace ck {
// 2. CThreadBuffer is StaticBuffer // 2. CThreadBuffer is StaticBuffer
// Also assume: // Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) // M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize, template <
typename FloatA, index_t BlockSize,
typename FloatB, typename FloatA,
typename FloatC, typename FloatB,
typename AKMBlockDesc, typename FloatC,
typename BKNBlockDesc, typename AKMBlockDesc,
index_t M1PerThreadM11, typename BKNBlockDesc,
index_t N1PerThreadN11, index_t M1PerThreadM11,
index_t KPerThread, index_t N1PerThreadN11,
index_t M1N1ThreadClusterM100, index_t KPerThread,
index_t M1N1ThreadClusterN100, index_t M1N1ThreadClusterM100,
index_t M1N1ThreadClusterM101, index_t M1N1ThreadClusterN100,
index_t M1N1ThreadClusterN101, index_t M1N1ThreadClusterM101,
index_t AThreadCopyScalarPerVector_M11, index_t M1N1ThreadClusterN101,
index_t BThreadCopyScalarPerVector_N11, index_t AThreadCopyScalarPerVector_M11,
typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() && index_t BThreadCopyScalarPerVector_N11,
BKNBlockDesc::IsKnownAtCompileTime(), typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{ {
using AIndex = MultiIndex<3>; using AIndex = MultiIndex<3>;
@@ -71,9 +71,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
static constexpr index_t N0 = N / N1; static constexpr index_t N0 = N / N1;
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeAKM0M1BlockDescriptor(const AKMBlockDesc& a_k_m_block_desc) MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
{ {
const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor( const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
AKMBlockDesc{}, AKMBlockDesc{},
make_tuple(make_pass_through_transform(Number<K>{}), make_tuple(make_pass_through_transform(Number<K>{}),
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))), make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
@@ -84,9 +84,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc) MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
{ {
const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor( const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
BKNBlockDesc{}, BKNBlockDesc{},
make_tuple(make_pass_through_transform(Number<K>{}), make_tuple(make_pass_through_transform(Number<K>{}),
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))), make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
@@ -194,7 +194,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
typename ABlockBuffer, typename ABlockBuffer,
typename BBlockBuffer, typename BBlockBuffer,
typename CThreadBuffer> typename CThreadBuffer>
__device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc, __device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */,
const ABlockBuffer& a_block_buf, const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
@@ -357,34 +357,32 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
private: private:
// A[K, M0, M1] // A[K, M0, M1]
static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{})); make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));
// B[K, N0, N1] // B[K, N0, N1]
static constexpr auto b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{})); make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));
using AThreadCopy = using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA, FloatA,
FloatA, decltype(a_k_m0_m1_block_desc_),
decltype(a_k_m0_m1_block_desc_), decltype(a_k_m0_m1_thread_desc_),
decltype(a_k_m0_m1_thread_desc_), Sequence<KPerThread, 1, M1PerThreadM11>,
Sequence<KPerThread, 1, M1PerThreadM11>, Sequence<0, 1, 2>,
Sequence<0, 1, 2>, 2,
2, AThreadCopyScalarPerVector_M11,
AThreadCopyScalarPerVector_M11, 1>;
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
using BThreadCopy = FloatB,
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB, decltype(b_k_n0_n1_block_desc_),
FloatB, decltype(b_k_n0_n1_thread_desc_),
decltype(b_k_n0_n1_block_desc_), Sequence<KPerThread, 1, N1PerThreadN11>,
decltype(b_k_n0_n1_thread_desc_), Sequence<0, 1, 2>,
Sequence<KPerThread, 1, N1PerThreadN11>, 2,
Sequence<0, 1, 2>, BThreadCopyScalarPerVector_N11,
2, 1>;
BThreadCopyScalarPerVector_N11,
1>;
CIndex c_thread_origin_data_idx_; CIndex c_thread_origin_data_idx_;
...
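The "2x2 pipelined read and fma (ABBA optimization)" comment in this file refers to interleaving the next sub-tile read with the previous sub-tile's FMAs. Below is a host-side toy of one plausible interleaving, with scalars standing in for the A/B thread tiles; it is a schematic of the idea, not the GPU kernel.

// Toy 2x2 pipelined read/FMA ordering (host-side illustration only).
#include <cstdio>

int main()
{
    // Scalars stand in for the A[m0] and B[n0] thread sub-tiles.
    const double A[2] = {1.0, 2.0};
    const double B[2] = {3.0, 4.0};
    double C[2][2]    = {};

    auto fma_tile = [&](int m0, int n0, double a, double b) { C[m0][n0] += a * b; };

    double a0 = A[0], b0 = B[0]; // prologue: first A and B sub-tile "reads"
    double b1 = B[1];            // issue next B read ...
    fma_tile(0, 0, a0, b0);      // ... while computing A0 * B0
    double a1 = A[1];            // issue next A read ...
    fma_tile(0, 1, a0, b1);      // ... while computing A0 * B1
    fma_tile(1, 1, a1, b1);      // drain: A1 * B1
    fma_tile(1, 0, a1, b0);      //        A1 * B0 (B sub-tiles touched in 0, 1, 1, 0 order)

    std::printf("C = [[%g, %g], [%g, %g]]\n", C[0][0], C[0][1], C[1][0], C[1][1]);
    return 0;
}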
@@ -3,7 +3,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_adaptor.hpp" #include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_contraction_dlops.hpp" #include "threadwise_contraction_dlops.hpp"
namespace ck { namespace ck {
@@ -38,9 +38,9 @@ template <index_t BlockSize,
// BM10BN10ThreadClusterBN101, ...> // BM10BN10ThreadClusterBN101, ...>
index_t AThreadCopyScalarPerVector_BM11, index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11, index_t BThreadCopyScalarPerVector_BN11,
typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{ {
using AIndex = MultiIndex<3>; using AIndex = MultiIndex<3>;
@@ -75,7 +75,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
{ {
const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor( const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
a_block_desc_bk0_bm_bk1, a_block_desc_bk0_bm_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}), make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})), make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
@@ -89,7 +89,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
{ {
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor( const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
b_block_desc_bk0_bn_bk1, b_block_desc_bk0_bn_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}), make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})), make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
@@ -372,15 +372,15 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
private: private:
// A[BK0, BM0, BM1, BK1] // A[BK0, BM0, BM1, BK1]
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( make_naive_tensor_descriptor_packed(make_tuple(
Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{})); Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));
// B[BK0, BN0, BN1, BK1] // B[BK0, BN0, BN1, BK1]
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( make_naive_tensor_descriptor_packed(make_tuple(
Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{})); Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1< using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatA, FloatA,
FloatA, FloatA,
decltype(a_block_desc_bk0_bm0_bm1_bk1_), decltype(a_block_desc_bk0_bm0_bm1_bk1_),
@@ -390,7 +390,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1< using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatB, FloatB,
FloatB, FloatB,
decltype(b_block_desc_bk0_bn0_bn1_bk1_), decltype(b_block_desc_bk0_bn0_bn1_bk1_),
...
@@ -31,25 +31,24 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
// HACK: fix this @Jing Zhang // HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4; static constexpr index_t KPerThreadSubC = 4;
static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{})); make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{})); Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{})); Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
using AThreadCopy = using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA, FloatA,
FloatA, BlockMatrixA,
BlockMatrixA, decltype(a_thread_mtx_),
decltype(a_thread_mtx_), Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<EPerThreadLoop, KPerThreadSubC>, Sequence<0, 1>,
Sequence<0, 1>, 1,
1, ThreadGemmADataPerRead_K,
ThreadGemmADataPerRead_K, 1>;
1>;
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
@@ -69,7 +68,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
"wrong! K dimension not consistent\n"); "wrong! K dimension not consistent\n");
constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
constexpr index_t H = BlockMatrixB{}.GetLength(I2); constexpr index_t H = BlockMatrixB{}.GetLength(I2);
constexpr index_t W = BlockMatrixB{}.GetLength(I3); constexpr index_t W = BlockMatrixB{}.GetLength(I3);
@@ -121,9 +119,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
"wrong! inconsistent type"); "wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto a_block_mtx = BlockMatrixA{}; constexpr auto a_block_mtx = BlockMatrixA{};
@@ -138,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static_assert(WPerThread % WoPerThreadSubC == 0, ""); static_assert(WPerThread % WoPerThreadSubC == 0, "");
// thread A buffer for GEMM // thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()> StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf; a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA, constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
...
@@ -2,7 +2,7 @@
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP #define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp" #include "xdlops_gemm.hpp"
namespace ck { namespace ck {
@@ -52,7 +52,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const index_t waveId = thread_id / WaveSize; const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize; const index_t laneId = thread_id % WaveSize;
const index_t waveId_m = waveId / NWaves; const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
if constexpr(xdlops_gemm.IsKReduction) if constexpr(xdlops_gemm.IsKReduction)
{ {
@@ -73,7 +72,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const index_t thread_id = get_thread_local_1d_id(); const index_t thread_id = get_thread_local_1d_id();
const index_t waveId = thread_id / WaveSize; const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize; const index_t laneId = thread_id % WaveSize;
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves; const index_t waveId_n = waveId % NWaves;
if constexpr(xdlops_gemm.IsKReduction) if constexpr(xdlops_gemm.IsKReduction)
@@ -193,35 +191,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
private: private:
// A[K, M] // A[K, M]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto a_thread_desc_ =
make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
// B[K, N] // B[K, N]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto b_thread_desc_ =
make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto c_thread_desc_ =
make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB, using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
ABlockDesc, ABlockDesc,
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, MRepeat, 1, K1>, Sequence<1, MRepeat, 1, K1>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
3, 3,
K1, K1,
1>; 1>;
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB, using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
BBlockDesc, BBlockDesc,
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, NRepeat, 1, K1>, Sequence<1, NRepeat, 1, K1>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
3, 3,
K1, K1,
1>; 1>;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_; BThreadCopy b_thread_copy_;
@@ -272,7 +270,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
const index_t waveId = thread_id / WaveSize; const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize; const index_t laneId = thread_id % WaveSize;
const index_t waveId_m = waveId / NWaves; const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
if constexpr(xdlops_gemm.IsKReduction) if constexpr(xdlops_gemm.IsKReduction)
{ {
@@ -293,7 +290,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
const index_t thread_id = get_thread_local_1d_id(); const index_t thread_id = get_thread_local_1d_id();
const index_t waveId = thread_id / WaveSize; const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize; const index_t laneId = thread_id % WaveSize;
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves; const index_t waveId_n = waveId % NWaves;
if constexpr(xdlops_gemm.IsKReduction) if constexpr(xdlops_gemm.IsKReduction)
@@ -490,35 +486,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
private: private:
// A[K, M] // A[K, M]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto a_thread_desc_ =
make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
// B[K, N] // B[K, N]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto b_thread_desc_ =
make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto c_thread_desc_ =
make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB, using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
ABlockDesc, ABlockDesc,
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, 1, 1, K1>, Sequence<1, 1, 1, K1>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
3, 3,
1, // K1, 1, // K1,
1>; 1>;
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB, using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
BBlockDesc, BBlockDesc,
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, 1, 1, K1>, Sequence<1, 1, 1, K1>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
3, 3,
1, // K1, 1, // K1,
1>; 1>;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_; BThreadCopy b_thread_copy_;
...
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP #ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP #define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp" #include "cluster_descriptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
namespace ck { namespace ck {
// this version does following things to avoid scratch memory issue // this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer // 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
@@ -33,16 +33,16 @@ template <index_t BlockSize,
index_t DstScalarStrideInVector, index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun> bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4 struct BlockwiseTensorSliceTransfer_v4
{ {
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(const SrcDesc& src_desc, __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
const Index& src_block_slice_origin, const Index& src_block_slice_origin,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_block_slice_origin) const Index& dst_block_slice_origin)
: threadwise_transfer_( : threadwise_transfer_(
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>()) src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
@@ -77,15 +77,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
} }
} }
template <typename SrcBuffer, typename SrcIteratorHacks> template <typename SrcBuffer, typename SrcStepHacks>
__device__ void RunRead(const SrcDesc& src_desc, __device__ void
const SrcBuffer& src_buf, RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
const SrcIteratorHacks& src_iterator_hacks)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks); threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
} }
} }
@@ -118,18 +117,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
} }
} }
// SrcMoveSliceWindowIteratorHack to control index calculation move slice window // SrcMoveSliceWindowStepHack to control index calculation move slice window
template <typename SrcMoveSliceWindowIteratorHack> template <typename SrcMoveSliceWindowStepHack>
__device__ void __device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc, MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step, const Index& step,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrcSliceWindow( threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_iterator_hack); src_desc, step, src_move_slice_window_step_hack);
} }
} }
@@ -147,22 +146,22 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer = using ThreadwiseTransfer =
ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths, ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
DstInMemOp, DstInMemOp,
SrcData, SrcData,
DstData, DstData,
SrcDesc, SrcDesc,
DstDesc, DstDesc,
SrcDimAccessOrder, SrcDimAccessOrder,
DstDimAccessOrder, DstDimAccessOrder,
SrcVectorDim, SrcVectorDim,
DstVectorDim, DstVectorDim,
SrcScalarPerVector, SrcScalarPerVector,
DstScalarPerVector, DstScalarPerVector,
SrcScalarStrideInVector, SrcScalarStrideInVector,
DstScalarStrideInVector, DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>; ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_; ThreadwiseTransfer threadwise_transfer_;
}; };
...
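The blockwise transfer in this file only delegates: each participating thread runs its own threadwise transfer on its sub-slice, and threads whose id falls outside the thread-cluster size skip the call. A host-side toy of that guard follows; names and sizes are illustrative, not the library API.

// Toy of the thread-cluster guard used by RunRead/RunWrite/MoveSrcSliceWindow above.
#include <cstdio>

int main()
{
    constexpr int BlockSize   = 8; // threads launched per block
    constexpr int ClusterSize = 6; // threads the block slice is partitioned over

    for(int tid = 0; tid < BlockSize; ++tid) // stand-in for per-thread execution
    {
        // Same shape of condition as in the blockwise transfer above.
        if(BlockSize == ClusterSize || tid < ClusterSize)
        {
            std::printf("thread %d runs its threadwise transfer on its sub-slice\n", tid);
        }
        else
        {
            std::printf("thread %d idles (outside the thread cluster)\n", tid);
        }
    }
    return 0;
}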
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP #ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP #define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp" #include "cluster_descriptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_tensor_slice_transfer_v2.hpp"
namespace ck { namespace ck {
// this version does following things to avoid scratch memory issue // this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer // 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
@@ -31,17 +31,16 @@ template <index_t BlockSize,
typename DstVectorTensorContiguousDimOrder, typename DstVectorTensorContiguousDimOrder,
bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun> bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4r1 struct BlockwiseTensorSliceTransfer_v4r1
{ {
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1( __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc,
const SrcDesc& src_desc, const Index& src_block_slice_origin,
const Index& src_block_slice_origin, const DstDesc& dst_desc,
const DstDesc& dst_desc, const Index& dst_block_slice_origin)
const Index& dst_block_slice_origin)
: threadwise_transfer_( : threadwise_transfer_(
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>()) src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
@@ -76,15 +75,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
} }
} }
template <typename SrcBuffer, typename SrcIteratorHacks> template <typename SrcBuffer, typename SrcStepHacks>
__device__ void RunRead(const SrcDesc& src_desc, __device__ void
const SrcBuffer& src_buf, RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
const SrcIteratorHacks& src_iterator_hacks)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks); threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
} }
} }
@@ -107,18 +105,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
} }
} }
// SrcMoveSliceWindowIteratorHack to control index calculation move slice window // SrcMoveSliceWindowStepHack to control index calculation move slice window
template <typename SrcMoveSliceWindowIteratorHack> template <typename SrcMoveSliceWindowStepHack>
__device__ void __device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc, MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step, const Index& step,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrcSliceWindow( threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_iterator_hack); src_desc, step, src_move_slice_window_step_hack);
} }
} }
@@ -136,20 +134,20 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer = using ThreadwiseTransfer =
ThreadwiseDynamicTensorSliceTransfer_v3r1<ThreadSliceLengths, ThreadwiseTensorSliceTransfer_v3r1<ThreadSliceLengths,
DstInMemOp, DstInMemOp,
SrcData, SrcData,
DstData, DstData,
SrcDesc, SrcDesc,
DstDesc, DstDesc,
SrcDimAccessOrder, SrcDimAccessOrder,
DstDimAccessOrder, DstDimAccessOrder,
SrcVectorTensorLengths, SrcVectorTensorLengths,
DstVectorTensorLengths, DstVectorTensorLengths,
SrcVectorTensorContiguousDimOrder, SrcVectorTensorContiguousDimOrder,
DstVectorTensorContiguousDimOrder, DstVectorTensorContiguousDimOrder,
ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>; ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_; ThreadwiseTransfer threadwise_transfer_;
}; };
...
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP #ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP #define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp" #include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" #include "blockwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp" #include "threadwise_tensor_slice_set.hpp"
namespace ck { namespace ck {
@@ -25,7 +25,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_contraction_dlops_v1r2( kernel_contraction_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
@@ -84,12 +84,12 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridStepHacks,
typename BGridIteratorHacks, typename BGridStepHacks,
typename CGridIteratorHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowStepHacks>
struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -110,17 +110,15 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -110,17 +110,15 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned_v2(
make_dynamic_naive_tensor_descriptor_aligned_v2( make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned_v2(
make_dynamic_naive_tensor_descriptor_aligned_v2( make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple( constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
...@@ -201,7 +199,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -201,7 +199,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
const auto GM11 = Number<GM1PerBlockGM11>{}; const auto GM11 = Number<GM1PerBlockGM11>{};
const auto GM10 = GM1 / GM11; const auto GM10 = GM1 / GM11;
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_dynamic_tensor_descriptor( const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor(
a_grid_desc_gk0_gm0_gm1_gk1, a_grid_desc_gk0_gm0_gm1_gk1,
make_tuple(make_pass_through_transform(GK0), make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GM0), make_pass_through_transform(GM0),
...@@ -222,7 +220,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -222,7 +220,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
const auto GN11 = Number<GN1PerBlockGN11>{}; const auto GN11 = Number<GN1PerBlockGN11>{};
const auto GN10 = GN1 / GN11; const auto GN10 = GN1 / GN11;
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_dynamic_tensor_descriptor( const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor(
b_grid_desc_gk0_gn0_gn1_gk1, b_grid_desc_gk0_gn0_gn1_gk1,
make_tuple(make_pass_through_transform(GK0), make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GN0), make_pass_through_transform(GN0),
...@@ -259,7 +257,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -259,7 +257,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
constexpr auto BM0 = BM / BM1; constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1; constexpr auto BN0 = BN / BN1;
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor( const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor(
c_grid_desc_gm0_gm1_gn0_gn1, c_grid_desc_gm0_gm1_gn0_gn1,
make_tuple(make_pass_through_transform(GM0), make_tuple(make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)), make_unmerge_transform(make_tuple(GM10, GM11)),
...@@ -268,7 +266,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -268,7 +266,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor( const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor(
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc, c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
make_tuple(make_pass_through_transform(GM10), make_tuple(make_pass_through_transform(GM10),
make_merge_transform(make_tuple(GM0, GM11)), make_merge_transform(make_tuple(GM0, GM11)),
...@@ -277,7 +275,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -277,7 +275,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_dynamic_tensor_descriptor( const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor(
c_gm10_bm_gn10_bn_grid_desc, c_gm10_bm_gn10_bn_grid_desc,
make_tuple(make_pass_through_transform(GM10), make_tuple(make_pass_through_transform(GM10),
make_unmerge_transform(make_tuple(BM0, BM1)), make_unmerge_transform(make_tuple(BM0, BM1)),
...@@ -356,26 +354,24 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -356,26 +354,24 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned_v2(
make_dynamic_naive_tensor_descriptor_aligned_v2( make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned_v2(
make_dynamic_naive_tensor_descriptor_aligned_v2( make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
max_lds_align);
// A matrix in LDS memory for blockwise GEMM // A matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_block_desc_gk0_bm_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align); make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
// B matrix in LDS memory for blockwise GEMM // B matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_block_desc_gk0_bn_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align); make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() == static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
...@@ -385,7 +381,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -385,7 +381,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
"wrong!"); "wrong!");
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>, Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
...@@ -409,7 +405,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -409,7 +405,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
make_multi_index(0, 0, 0, 0, 0)); make_multi_index(0, 0, 0, 0, 0));
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>, Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
...@@ -457,9 +453,8 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -457,9 +453,8 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 = constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed(
make_dynamic_naive_tensor_descriptor_packed_v2( sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple( constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
...@@ -475,9 +470,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -475,9 +470,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_thread_desc_bm0_bm1_bn0_bn1), decltype(c_thread_desc_bm0_bm1_bn0_bn1),
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{} decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
.Run(c_thread_desc_bm0_bm1_bn0_bn1, .Run(c_thread_desc_bm0_bm1_bn0_bn1,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
...@@ -501,9 +496,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -501,9 +496,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
...@@ -520,18 +515,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -520,18 +515,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1, blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
...@@ -546,18 +541,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -546,18 +541,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -576,18 +571,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -576,18 +571,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -615,7 +610,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -615,7 +610,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// output: register to global memory // output: register to global memory
{ {
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 = constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
make_dynamic_naive_tensor_descriptor_packed_v2( make_naive_tensor_descriptor_packed(
make_tuple(I1, make_tuple(I1,
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{}, Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{}, Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
...@@ -627,7 +622,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -627,7 +622,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id()); get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, FloatAcc,
FloatC, FloatC,
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1), decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
...@@ -655,7 +650,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -655,7 +650,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
c_thread_buf, c_thread_buf,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_buf, c_grid_buf,
CGridIteratorHacks{}); CGridStepHacks{});
} }
} }
}; };
......
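The main loop above follows the usual LDS double-buffer schedule: preload the first K-tile, then alternate even/odd buffers so the next tile is fetched from global memory while the blockwise GEMM consumes the previous one, and finish with a tail step. A plain-C++ sketch of that schedule (no GPU, names and tile count illustrative only):
#include <array>
#include <cstdio>
#include <vector>
int main()
{
    const int num_k_tiles = 6;               // assumed K-loop length
    std::array<std::vector<int>, 2> lds_buf; // stand-ins for the even/odd LDS buffers
    auto load_tile = [](int k) { return std::vector<int>(4, k); };                         // RunRead + RunWrite stand-in
    auto gemm_on   = [](const std::vector<int>& t) { std::printf("gemm on tile %d\n", t[0]); };
    lds_buf[0] = load_tile(0);               // preload into the even buffer
    for(int k = 1; k < num_k_tiles; ++k)
    {
        lds_buf[k % 2] = load_tile(k);       // fetch the next tile into the other buffer
        gemm_on(lds_buf[(k - 1) % 2]);       // compute on the tile loaded in the previous step
    }
    gemm_on(lds_buf[(num_k_tiles - 1) % 2]); // tail: the last tile has no successor to prefetch
    return 0;
}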
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP #ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP #define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r2.hpp" #include "blockwise_gemm_dlops_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp" #include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp" #include "threadwise_tensor_slice_set.hpp"
namespace ck { namespace ck {
...@@ -26,7 +26,7 @@ __global__ void ...@@ -26,7 +26,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_dlops_v1r2( kernel_gemm_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
...@@ -68,28 +68,27 @@ __global__ void ...@@ -68,28 +68,27 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_dlops_v1r2( kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid,
FloatC* __restrict__ p_c_grid, const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void CONSTANT* p_a_k_m0_m1_grid_desc, const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void CONSTANT* p_b_k_n0_n1_grid_desc, const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{ {
// first cast void CONSTANT void* to void* // first cast void CONSTANT void* to void*
// second cast void* to Desc* // second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4) // the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_k_m0_m1_grid_desc = const auto a_k_m0_m1_grid_desc = *reinterpret_cast<const AKM0M1GridDesc*>(
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc); cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc));
const auto b_k_n0_n1_grid_desc = const auto b_k_n0_n1_grid_desc = *reinterpret_cast<const BKN0N1GridDesc*>(
*reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc); cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc));
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>( *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>( *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
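A minimal host-side sketch of the pointer handling in the kernel above: the grid descriptors arrive as opaque pointers and are cast back to their concrete types before being copied by value; the device-side address-space cast is modelled here by an ordinary reinterpret_cast. The struct and values are placeholders, not the real descriptor type.
#include <cstdio>
struct GridDesc { int K, M; }; // placeholder for the real tensor descriptor type
void kernel_like(const void* p_desc)
{
    // recover the typed descriptor and copy it by value, as the kernel does
    const auto desc = *reinterpret_cast<const GridDesc*>(p_desc);
    std::printf("K=%d M=%d\n", desc.K, desc.M);
}
int main()
{
    GridDesc d{64, 128};
    kernel_like(&d); // in the kernel the pointer additionally lives in the CONSTANT address space
    return 0;
}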
...@@ -146,12 +145,12 @@ template <index_t BlockSize, ...@@ -146,12 +145,12 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridStepHacks,
typename BGridIteratorHacks, typename BGridStepHacks,
typename CGridIteratorHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowStepHacks>
struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 struct GridwiseGemmDlops_km_kn_mn_v1r2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -167,12 +166,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -167,12 +166,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -230,7 +229,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -230,7 +229,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
const auto M1 = Number<MPerBlockM1>{}; const auto M1 = Number<MPerBlockM1>{};
const auto M0 = M / M1; const auto M0 = M / M1;
const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor( const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor(
a_k_m_grid_desc, a_k_m_grid_desc,
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))), make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
...@@ -248,7 +247,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -248,7 +247,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
const auto N1 = Number<NPerBlockN1>{}; const auto N1 = Number<NPerBlockN1>{};
const auto N0 = N / N1; const auto N0 = N / N1;
const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor( const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor(
b_k_n_grid_desc, b_k_n_grid_desc,
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))), make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
...@@ -277,7 +276,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -277,7 +276,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
constexpr auto M10 = M1 / M11; constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11; constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor( const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc, c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))), make_unmerge_transform(make_tuple(N0, N10, N11))),
...@@ -352,75 +351,75 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -352,75 +351,75 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align); make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align); make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1>, Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1, ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_k_m0_m1_grid_desc), decltype(a_k_m0_m1_grid_desc),
decltype(a_k_m0_m1_block_desc), decltype(a_k_m0_m1_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
2, 2,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1, ABlockTransferDstScalarPerVector_M1,
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>(a_k_m0_m1_grid_desc, true>(a_k_m0_m1_grid_desc,
make_multi_index(0, im0, 0), make_multi_index(0, im0, 0),
a_k_m0_m1_block_desc, a_k_m0_m1_block_desc,
make_multi_index(0, 0, 0)); make_multi_index(0, 0, 0));
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1>, Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1, BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(b_k_n0_n1_grid_desc), decltype(b_k_n0_n1_grid_desc),
decltype(b_k_n0_n1_block_desc), decltype(b_k_n0_n1_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
2, 2,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1, BBlockTransferDstScalarPerVector_N1,
1, 1,
1, 1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>(b_k_n0_n1_grid_desc, true>(b_k_n0_n1_grid_desc,
make_multi_index(0, in0, 0), make_multi_index(0, in0, 0),
b_k_n0_n1_block_desc, b_k_n0_n1_block_desc,
make_multi_index(0, 0, 0)); make_multi_index(0, 0, 0));
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
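A tiny numeric check of the contract stated in that comment, with A stored K x M and B stored K x N so that transpose(A) * B is M x N; the sizes are illustrative only.
#include <cstdio>
int main()
{
    const int K = 2, M = 2, N = 2;
    double A[K][M] = {{1, 2}, {3, 4}}; // a_mtx, K x M
    double B[K][N] = {{5, 6}, {7, 8}}; // b_mtx, K x N
    double C[M][N] = {};               // c_mtx, M x N, accumulated into
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                C[m][n] += A[k][m] * B[k][n]; // transpose(a_mtx) * b_mtx
    std::printf("C = [[%g, %g], [%g, %g]]\n", C[0][0], C[0][1], C[1][0], C[1][1]);
    return 0;
}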
...@@ -447,9 +446,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -447,9 +446,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
constexpr auto c_m10_m11_n10_n11_thread_desc = constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
make_dynamic_naive_tensor_descriptor_packed_v2( sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = constexpr auto a_block_aligned_space_size =
...@@ -465,9 +463,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -465,9 +463,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc), decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{} decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc, .Run(c_m10_m11_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
...@@ -477,15 +475,15 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -477,15 +475,15 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over the A and B matrices for threadwise copy // hack to control index calculation when iterating over the A and B matrices for threadwise copy
constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{}; constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{}; constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};
// hack to control index calculation when moving the slice window for the A and B matrices for // hack to control index calculation when moving the slice window for the A and B matrices for
// threadwise copy // threadwise copy
constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack = constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
AGridMoveSliceWindowIteratorHacks{}; AGridMoveSliceWindowStepHacks{};
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack = constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
BGridMoveSliceWindowIteratorHacks{}; BGridMoveSliceWindowStepHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
...@@ -502,9 +500,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -502,9 +500,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
...@@ -519,22 +517,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -519,22 +517,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
do do
{ {
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow( a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_k_m0_m1_grid_desc, a_block_slice_copy_step,
a_block_slice_copy_step, a_k_m0_m1_global_move_slice_window_step_hack);
a_k_m0_m1_global_move_slice_window_iterator_hack); b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_blockwise_copy.MoveSrcSliceWindow( b_block_slice_copy_step,
b_k_n0_n1_grid_desc, b_k_n0_n1_global_move_slice_window_step_hack);
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
...@@ -547,22 +543,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -547,22 +543,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow( a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_k_m0_m1_grid_desc, a_block_slice_copy_step,
a_block_slice_copy_step, a_k_m0_m1_global_move_slice_window_step_hack);
a_k_m0_m1_global_move_slice_window_iterator_hack); b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_blockwise_copy.MoveSrcSliceWindow( b_block_slice_copy_step,
b_k_n0_n1_grid_desc, b_k_n0_n1_global_move_slice_window_step_hack);
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -581,18 +575,18 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -581,18 +575,18 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack); a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack); b_k_n0_n1_global_move_slice_window_step_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -619,19 +613,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -619,19 +613,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
// output: register to global memory // output: register to global memory
{ {
constexpr index_t M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
constexpr index_t N11 =
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
constexpr index_t M10 = MPerBlockM1 / M11;
constexpr index_t N10 = NPerBlockN1 / N11;
constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2( make_naive_tensor_descriptor_packed(
make_tuple(I1, make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{}, Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{}, Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
...@@ -642,7 +625,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -642,7 +625,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, FloatAcc,
FloatC, FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
...@@ -670,7 +653,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -670,7 +653,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
c_thread_buf, c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf, c_grid_buf,
CGridIteratorHacks{}); CGridStepHacks{});
} }
} }
}; };
......
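Throughout these files the grid descriptors are reshaped with unmerge transforms, e.g. splitting M into (M0, M1) with M = M0 * M1; in terms of plain offsets that is just m = m0 * M1 + m1. A small sketch of that equivalence (row-major K x M layout assumed, names illustrative):
#include <cassert>
constexpr int K  = 8;
constexpr int M1 = 4;      // plays the role of MPerBlockM1
constexpr int M  = 16;     // must be divisible by M1
constexpr int M0 = M / M1;
static_assert(M0 * M1 == M, "M must factor as M0 * M1");
// original 2-d view, row-major
int offset_k_m(int k, int m) { return k * M + m; }
// transformed 3-d view produced by the unmerge: m = m0 * M1 + m1
int offset_k_m0_m1(int k, int m0, int m1) { return offset_k_m(k, m0 * M1 + m1); }
int main()
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            assert(offset_k_m0_m1(k, m / M1, m % M1) == offset_k_m(k, m));
    return 0;
}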
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP #ifndef CK_GRIDWISE_GEMM_V1R3_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP #define CK_GRIDWISE_GEMM_V1R3_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp" #include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" #include "blockwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp" #include "threadwise_tensor_slice_set.hpp"
namespace ck { namespace ck {
...@@ -26,7 +26,7 @@ __global__ void ...@@ -26,7 +26,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_dlops_v1r3( kernel_gemm_dlops_v1r3(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
...@@ -68,28 +68,27 @@ __global__ void ...@@ -68,28 +68,27 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_dlops_v1r3( kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid,
FloatC* __restrict__ p_c_grid, const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc, const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc,
const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc, const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{ {
// first cast void CONSTANT void* to void* // first cast void CONSTANT void* to void*
// second cast void* to Desc* // second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4) // the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_k0_m0_m1_k1_grid_desc = const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast<const AK0M0M1K1GridDesc*>(
*reinterpret_cast<const AK0M0M1K1GridDesc*>((const void*)p_a_k0_m0_m1_k1_grid_desc); cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc));
const auto b_k0_n0_n1_k1_grid_desc = const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast<const BK0N0N1K1GridDesc*>(
*reinterpret_cast<const BK0N0N1K1GridDesc*>((const void*)p_b_k0_n0_n1_k1_grid_desc); cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc));
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>( *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>( *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -142,12 +141,12 @@ template <index_t BlockSize, ...@@ -142,12 +141,12 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridStepHacks,
typename BGridIteratorHacks, typename BGridStepHacks,
typename CGridIteratorHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowStepHacks>
struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 struct GridwiseGemmDlops_km_kn_mn_v1r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -164,12 +163,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -164,12 +163,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
...@@ -191,12 +190,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -191,12 +190,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
const auto M = a_k0_m_k1_grid_desc.GetLength(I1); const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto K1 = a_k0_m_k1_grid_desc.GetLength(I2);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
} }
...@@ -231,13 +230,13 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -231,13 +230,13 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
const auto M1 = Number<MPerBlockM1>{}; const auto M1 = Number<MPerBlockM1>{};
const auto M0 = M / M1; const auto M0 = M / M1;
const auto a_k0_m0_m1_k1_grid_desc = transform_dynamic_tensor_descriptor( const auto a_k0_m0_m1_k1_grid_desc =
a_k0_m_k1_grid_desc, transform_tensor_descriptor(a_k0_m_k1_grid_desc,
make_tuple(make_pass_through_transform(K0), make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)), make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_k0_m0_m1_k1_grid_desc; return a_k0_m0_m1_k1_grid_desc;
} }
...@@ -251,13 +250,13 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -251,13 +250,13 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
const auto N1 = Number<NPerBlockN1>{}; const auto N1 = Number<NPerBlockN1>{};
const auto N0 = N / N1; const auto N0 = N / N1;
const auto b_k0_n0_n1_k1_grid_desc = transform_dynamic_tensor_descriptor( const auto b_k0_n0_n1_k1_grid_desc =
b_k0_n_k1_grid_desc, transform_tensor_descriptor(b_k0_n_k1_grid_desc,
make_tuple(make_pass_through_transform(K0), make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)), make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_k0_n0_n1_k1_grid_desc; return b_k0_n0_n1_k1_grid_desc;
} }
...@@ -284,7 +283,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -284,7 +283,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
constexpr auto M10 = M1 / M11; constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11; constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor( const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc, c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))), make_unmerge_transform(make_tuple(N0, N10, N11))),
...@@ -355,23 +354,23 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -355,23 +354,23 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k0_m0_m1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k0_n0_n1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM // A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM // B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
...@@ -381,7 +380,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -381,7 +380,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
"wrong!"); "wrong!");
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>, Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
...@@ -405,7 +404,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -405,7 +404,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
make_multi_index(0, 0, 0, 0)); make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>, Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
...@@ -453,9 +452,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -453,9 +452,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_m10_m11_n10_n11_thread_desc = constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
make_dynamic_naive_tensor_descriptor_packed_v2( sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple( constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
...@@ -471,9 +469,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -471,9 +469,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc), decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{} decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc, .Run(c_m10_m11_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
...@@ -496,8 +494,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -496,8 +494,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
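// The prologue above fills the "even" LDS buffers so the first blockwise GEMM has data to
// work on; the main loop below then alternates even and odd iterations, fetching the next
// K-block from global memory while the current one is consumed.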
...@@ -516,18 +514,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -516,18 +514,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
b_blockwise_copy.RunRead(
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
...@@ -542,18 +538,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -542,18 +538,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
b_blockwise_copy.RunRead(
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -570,18 +564,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -570,18 +564,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(
a_block_slice_copy_step, a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
AGridMoveSliceWindowIteratorHacks{}); b_blockwise_copy.MoveSrcSliceWindow(
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -608,21 +600,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -608,21 +600,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// output: register to global memory // output: register to global memory
{ {
constexpr auto M11 =
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies_v2{}, I1) *
M1PerThreadM111>{};
constexpr auto N11 =
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies_v2{}, I1) *
N1PerThreadN111>{};
constexpr index_t M10 = MPerBlockM1 / M11;
constexpr index_t N10 = NPerBlockN1 / N11;
constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2( make_naive_tensor_descriptor_packed(
make_tuple(I1, make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{}, Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{}, Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
...@@ -634,7 +613,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -634,7 +613,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id()); get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, FloatAcc,
FloatC, FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
...@@ -662,7 +641,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -662,7 +641,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
c_thread_buf, c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf, c_grid_buf,
CGridIteratorHacks{}); CGridStepHacks{});
} }
} }
}; };
......
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP #ifndef CK_GRIDWISE_GEMM_V2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP #define CK_GRIDWISE_GEMM_V2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp" #include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "blockwise_gemm_dlops_v3.hpp" #include "blockwise_gemm_dlops_v3.hpp"
namespace ck { namespace ck {
...@@ -42,12 +42,12 @@ template <index_t BlockSize, ...@@ -42,12 +42,12 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks, typename AGlobalStepHacks,
typename BGlobalIteratorHacks, typename BGlobalStepHacks,
typename CGlobalIteratorHacks, typename CGlobalStepHacks,
typename AGlobalMoveSliceWindowIteratorHacks, typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowIteratorHacks> typename BGlobalMoveSliceWindowStepHacks>
struct GridwiseDynamicGemmDlops_km_kn_mn_v3 struct GridwiseGemmDlops_km_kn_mn_v3
{ {
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
...@@ -58,7 +58,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -58,7 +58,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -102,7 +102,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -102,7 +102,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// divide block work by [M, N] // divide block work by [M, N]
#if 0 #if 0
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto ho_block_work_num = Ho / Number<HoPerBlock>{}; const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{}; const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num; const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
...@@ -114,7 +113,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -114,7 +113,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#else #else
// Hack: this forces the result into SGPR // Hack: this forces the result into SGPR
const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock); const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock); const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num; const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
...@@ -134,23 +132,21 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -134,23 +132,21 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_e_n_ho_wo_block_desc = constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx // TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k_n_ho_wo_thread_desc = constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize, BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
...@@ -184,47 +180,46 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -184,47 +180,46 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<E, KPerBlock>, Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K, ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K, ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_e_k_global_desc), decltype(a_e_k_global_desc),
decltype(a_e_k_desc), decltype(a_e_k_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<0, 1>, Sequence<0, 1>,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
1, 1,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K, ABlockTransferDstScalarPerVector_K,
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>( true>(a_e_k_global_desc,
a_e_k_global_desc, make_multi_index(0, k_block_data_on_global),
make_multi_index(0, k_block_data_on_global), a_e_k_desc,
a_e_k_desc, make_multi_index(0, 0));
make_multi_index(0, 0));
constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
constexpr auto b_e_n_ho_wo_thread_desc = Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); auto b_threadwise_transfer =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2< FloatAB,
FloatAB, decltype(b_e_n_ho_wo_global_desc),
FloatAB, decltype(b_e_n_ho_wo_thread_desc),
decltype(b_e_n_ho_wo_global_desc), Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
decltype(b_e_n_ho_wo_thread_desc), BBlockTransferSrcAccessOrder,
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>, BBlockTransferSrcVectorDim,
BBlockTransferSrcAccessOrder, BBlockTransferSrcScalarPerVector,
BBlockTransferSrcVectorDim, 1,
BBlockTransferSrcScalarPerVector, true>(
1, b_e_n_ho_wo_global_desc,
true>(b_e_n_ho_wo_global_desc, make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
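// Note: in this kernel B is not staged through LDS; each thread streams its
// EPerBlock x 1 x HoPerThread x WoPerThread sub-tile straight from global memory into
// registers (b_thread_even_buf / b_thread_odd_buf below).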
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_shared_block, a_e_k_desc.GetElementSpaceSize()); p_shared_block, a_e_k_desc.GetElementSpaceSize());
...@@ -232,44 +227,45 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -232,44 +227,45 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc, FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
c_thread_buf; c_thread_buf;
// initialize output thread tensor // initialize output thread tensor
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k_n_ho_wo_thread_desc), decltype(c_k_n_ho_wo_thread_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{} Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
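// Each main-loop iteration advances the per-thread B read window by EPerBlock along the
// E (reduction) dimension; the other three dimensions stay fixed.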
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{}; constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{}; constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
// hack to control index calculation when moving the slice window for A and B matrix for // hack to control index calculation when moving the slice window for A and B matrix for
// threadwise copy // threadwise copy
constexpr auto a_e_k_global_move_slice_window_iterator_hack = constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
AGlobalMoveSliceWindowIteratorHacks{}; constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = BGlobalMoveSliceWindowStepHacks{};
BGlobalMoveSliceWindowIteratorHacks{};
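// The *StepHacks types are compile-time hints for the coordinate update: as the comments
// above say, they control how indices are recalculated when stepping through or sliding
// the copy windows (presumably so unchanged parts of the transform chain can be skipped).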
// double register buffer for b // double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB, FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
b_thread_even_buf, b_thread_odd_buf; b_thread_even_buf, b_thread_odd_buf;
// LDS double buffer: preload data // LDS double buffer: preload data
{ {
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks); a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_step_hacks);
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf); a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
} }
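// Preload summary: the A block tile is staged into LDS (a_block_buf) and the first B
// sub-tile already sits in registers (b_thread_even_buf), so the main loop can start
// computing while the next B sub-tile is fetched.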
...@@ -293,7 +289,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -293,7 +289,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window // TODO: @Zhang Jing: blockwise gemm should be able to move slice window
...@@ -309,7 +305,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -309,7 +305,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
...@@ -332,7 +328,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -332,7 +328,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
...@@ -351,23 +347,22 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -351,23 +347,22 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// output: register to global memory // output: register to global memory
{ {
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
const index_t k_thread_data_on_global = const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread; k_block_data_on_global + k_thread_id * KPerThread;
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatAcc, FloatC,
FloatC, decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_thread_desc), decltype(c_k_n_ho_wo_global_desc),
decltype(c_k_n_ho_wo_global_desc), Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
Sequence<KPerThread, 1, HoPerThread, WoPerThread>, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim,
CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector,
CThreadTransferDstScalarPerVector, CGlobalMemoryDataOperation,
CGlobalMemoryDataOperation, 1,
1, true>(
true>(
c_k_n_ho_wo_global_desc, c_k_n_ho_wo_global_desc,
make_multi_index( make_multi_index(
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global)) k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
...@@ -376,7 +371,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -376,7 +371,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
c_thread_buf, c_thread_buf,
c_k_n_ho_wo_global_desc, c_k_n_ho_wo_global_desc,
c_global_buf, c_global_buf,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_step_hacks);
} }
} }
......
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP #ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP #define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp" #include "blockwise_gemm_xdlops.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp" #include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp" #include "threadwise_tensor_slice_set.hpp"
namespace ck { namespace ck {
...@@ -24,13 +24,13 @@ __global__ void ...@@ -24,13 +24,13 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AK0MK1GridDesc a_k0_m_k1_grid_desc, const AK0MK1GridDesc a_k0_m_k1_grid_desc,
const BK0NK1GridDesc b_k0_n_k1_grid_desc, const BK0NK1GridDesc b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor c_block_cluster_adaptor) const CBlockClusterAdaptor c_block_cluster_adaptor)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -58,25 +58,25 @@ __global__ void ...@@ -58,25 +58,25 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_block_cluster_adaptor) const void CONSTANT* p_c_block_cluster_adaptor)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
const auto a_k0_m_k1_grid_desc = const auto a_k0_m_k1_grid_desc = *reinterpret_cast<const AK0MK1GridDesc*>(
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc); cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
const auto b_k0_n_k1_grid_desc = const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc); cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
const auto c_m0_m1_m2_n_grid_desc = const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast<const CM0M1M2NGridDesc*>(
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc); cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc));
const auto c_block_cluster_adaptor = const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
*reinterpret_cast<const CBlockClusterAdaptor*>((const void*)p_c_block_cluster_adaptor); cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
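// The descriptor arguments arrive as opaque `const void CONSTANT*` kernel parameters;
// they are converted to generic-address-space pointers and then reinterpreted back into
// their concrete descriptor types before use.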
__shared__ FloatAB p_shared_block[shared_block_size]; __shared__ FloatAB p_shared_block[shared_block_size];
...@@ -126,13 +126,13 @@ template <index_t BlockSize, ...@@ -126,13 +126,13 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridStepHacks,
typename BGridIteratorHacks, typename BGridStepHacks,
typename CGridIteratorHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowIteratorHacks, typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat> bool CAccessOrderMRepeatNRepeat>
struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -148,12 +148,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -148,12 +148,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -203,9 +203,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -203,9 +203,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc) MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{ {
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{}; constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout(); constexpr auto CLayout = xdlops_gemm.GetCLayout();
...@@ -217,10 +214,9 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -217,10 +214,9 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
constexpr auto N0 = Number<CLayout.N1()>{};
constexpr auto N1 = Number<CLayout.N0()>{}; constexpr auto N1 = Number<CLayout.N0()>{};
const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor( const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc, c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)), make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))), make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
...@@ -269,11 +265,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -269,11 +265,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc, const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor) const CBlockClusterAdaptor& c_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
...@@ -282,8 +273,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -282,8 +273,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize()); p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
...@@ -301,67 +290,65 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -301,67 +290,65 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, MPerBlock, K1>, Sequence<KPerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_k0_m_k1_grid_desc), decltype(a_k0_m_k1_grid_desc),
decltype(a_k0_m_k1_block_desc), decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>, Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
2, 2,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_K1,
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>( true>(a_k0_m_k1_grid_desc,
a_k0_m_k1_grid_desc, make_multi_index(0, m_block_data_idx_on_grid, 0),
make_multi_index(0, m_block_data_idx_on_grid, 0), a_k0_m_k1_block_desc,
a_k0_m_k1_block_desc, make_multi_index(0, 0, 0));
make_multi_index(0, 0, 0));
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, NPerBlock, K1>, Sequence<KPerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(b_k0_n_k1_grid_desc), decltype(b_k0_n_k1_grid_desc),
decltype(b_k0_n_k1_block_desc), decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>, Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
2, 2,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
1, 1,
1, 1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>( true>(b_k0_n_k1_grid_desc,
b_k0_n_k1_grid_desc, make_multi_index(0, n_block_data_idx_on_grid, 0),
make_multi_index(0, n_block_data_idx_on_grid, 0), b_k0_n_k1_block_desc,
b_k0_n_k1_block_desc, make_multi_index(0, 0, 0));
make_multi_index(0, 0, 0));
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
...@@ -375,7 +362,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -375,7 +362,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
NPerBlock % (NPerWave * NRepeat) == 0, NPerBlock % (NPerWave * NRepeat) == 0,
"wrong!"); "wrong!");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor( constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor(
a_k0_m_k1_block_desc, a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}), make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform( make_unmerge_transform(
...@@ -384,7 +371,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -384,7 +371,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor( constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor(
b_k0_n_k1_block_desc, b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}), make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform( make_unmerge_transform(
...@@ -410,21 +397,19 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -410,21 +397,19 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only"); static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
constexpr auto c_mr_nr_blk_desc = make_dynamic_naive_tensor_descriptor_packed_v2( constexpr auto c_mr_nr_blk_desc =
make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, BlkSize>, vector_type<FloatAcc, BlkSize>,
c_mr_nr_blk_desc.GetElementSpaceSize()> c_mr_nr_blk_desc.GetElementSpaceSize(),
true>
c_thread_buf; c_thread_buf;
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block; FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size; FloatAB* p_b_block = p_shared_block + a_block_space_size;
...@@ -432,15 +417,13 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -432,15 +417,13 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
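// Both copy steps advance the global read windows by KPerBlock along K0, i.e. one
// K-block per iteration of the main reduction loop.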
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{}; constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{}; constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
// hack to control index calculation when moving the slice window for A and B matrix for // hack to control index calculation when moving the slice window for A and B matrix for
// threadwise copy // threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack = constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
AGridMoveSliceWindowIteratorHacks{}; constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
...@@ -449,10 +432,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -449,10 +432,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// preload data into LDS // preload data into LDS
{ {
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
...@@ -465,18 +446,16 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -465,18 +446,16 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_iterator_hack); a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_iterator_hack); b_k0_n_k1_grid_move_slice_window_step_hack);
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
block_sync_lds(); block_sync_lds();
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
...@@ -506,7 +485,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -506,7 +485,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr index_t N1 = CLayout.N0(); constexpr index_t N1 = CLayout.N0();
constexpr auto c_m0_m1_m2_n_thread_desc = constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<MRepeat>{}, make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{}, Number<NRepeat>{},
Number<1>{}, Number<1>{},
Number<1>{}, Number<1>{},
...@@ -515,7 +494,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -515,7 +494,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
Number<M2>{}, Number<M2>{},
Number<1>{})); Number<1>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()> StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
c_blk_buf_; c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) { static_for<0, MRepeat, 1>{}([&](auto mr_i) {
...@@ -542,12 +521,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -542,12 +521,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid = const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{}; constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<
FloatC, FloatC,
FloatC, FloatC,
decltype(c_m0_m1_m2_n_thread_desc), decltype(c_m0_m1_m2_n_thread_desc),
...@@ -573,7 +552,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -573,7 +552,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_blk_buf_, c_blk_buf_,
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks); c_m0_m1_m2_n_grid_tensor_step_hacks);
} }
#else #else
{ {
...@@ -581,11 +560,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -581,11 +560,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr index_t M1 = CLayout.N1(); constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0(); constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc = constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( make_tuple(I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, BlkSize> c_blk_buf_;
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
...@@ -598,20 +574,20 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -598,20 +574,20 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid = const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{}; constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
auto c_thread_copy = auto c_thread_copy =
ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatC, ThreadwiseTensorSliceTransfer_v1r3<FloatC,
FloatC, FloatC,
decltype(c_m0_m1_m2_n_thread_desc), decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc), decltype(c_m0_m1_m2_n_grid_desc),
Sequence<1, 1, 1, 1, M0, 1, M2, 1>, Sequence<1, 1, 1, 1, M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>{ true>{
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
make_multi_index(0, make_multi_index(0,
0, 0,
...@@ -629,7 +605,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -629,7 +605,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks); c_m0_m1_m2_n_grid_tensor_step_hacks);
return c_thread_idx_; return c_thread_idx_;
}; };
...@@ -644,7 +620,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -644,7 +620,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks); c_m0_m1_m2_n_grid_tensor_step_hacks);
}; };
auto nrepeat_plus_copy = [&](auto c_thread_idx_) { auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
...@@ -657,7 +633,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -657,7 +633,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks); c_m0_m1_m2_n_grid_tensor_step_hacks);
}; };
auto mrepeat_minus_copy = [&](auto c_thread_idx_) { auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
...@@ -670,7 +646,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -670,7 +646,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks); c_m0_m1_m2_n_grid_tensor_step_hacks);
}; };
auto nrepeat_minus_copy = [&](auto c_thread_idx_) { auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
...@@ -683,7 +659,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -683,7 +659,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks); c_m0_m1_m2_n_grid_tensor_step_hacks);
}; };
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
......
...@@ -21,10 +21,10 @@ template <typename FloatA, ...@@ -21,10 +21,10 @@ template <typename FloatA,
typename TKLengths, typename TKLengths,
typename TMLengths, typename TMLengths,
typename TNLengths, typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
{ {
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
...@@ -97,10 +97,9 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 ...@@ -97,10 +97,9 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
amd_inner_product_dlop<FloatA, FloatB, FloatC>( inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}],
b_buf[Number<b_offset>{}], c_buf(Number<c_offset>{}));
c_buf(Number<c_offset>{}));
}); });
}); });
}); });
...@@ -124,10 +123,10 @@ template <typename FloatA, ...@@ -124,10 +123,10 @@ template <typename FloatA,
typename TKLengths, typename TKLengths,
typename TMLengths, typename TMLengths,
typename TNLengths, typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{ {
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
...@@ -214,7 +213,7 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ ...@@ -214,7 +213,7 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>( inner_product<a_vector_t, b_vector_t, FloatC>(
a_vec.template AsType<a_vector_t>()[I0], a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0], b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{})); c_buf(Number<c_offset>{}));
......
...@@ -19,9 +19,9 @@ template <typename FloatA, ...@@ -19,9 +19,9 @@ template <typename FloatA,
typename CDesc, typename CDesc,
index_t H, index_t H,
index_t W, index_t W,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3 struct ThreadwiseGemmDlops_km_kn_mn_v3
{ {
template <typename ABuffer, template <typename ABuffer,
...@@ -57,8 +57,6 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 ...@@ -57,8 +57,6 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto E = ADesc{}.GetLength(I0); constexpr auto E = ADesc{}.GetLength(I0);
constexpr auto K = ADesc{}.GetLength(I1); constexpr auto K = ADesc{}.GetLength(I1);
......
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP #ifndef CK_THREADWISE_TENSOR_SET_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP #define CK_THREADWISE_TENSOR_SET_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
...@@ -11,12 +11,12 @@ namespace ck { ...@@ -11,12 +11,12 @@ namespace ck {
// 1. Desc is known at compile-time // 1. Desc is known at compile-time
// 2. Buffer is StaticBuffer // 2. Buffer is StaticBuffer
// 3. OriginIdx is known at compile-time // 3. OriginIdx is known at compile-time
// 4. use #-iterator // 4. use #-step
template <typename Data, template <typename Data,
typename Desc, typename Desc,
typename SliceLengths, typename SliceLengths,
typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceSet_v1 struct ThreadwiseTensorSliceSet_v1
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
...@@ -40,7 +40,7 @@ struct ThreadwiseDynamicTensorSliceSet_v1 ...@@ -40,7 +40,7 @@ struct ThreadwiseDynamicTensorSliceSet_v1
constexpr auto origin_idx = to_multi_index(OriginIdx{}); constexpr auto origin_idx = to_multi_index(OriginIdx{});
static_ford<SliceLengths>{}([&](auto access_idx) { static_ford<SliceLengths>{}([&](auto access_idx) {
constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx); constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx);
constexpr bool is_valid = constexpr bool is_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord); coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
......
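ThreadwiseTensorSliceSet_v1 unrolls every index inside SliceLengths with static_ford and builds a tensor coordinate per access, skipping writes whose offset is not valid. A runtime analogue of that multi-dimensional visitation (an odometer over the slice lengths; in the CK code this is fully unrolled at compile time):

#include <array>
#include <cstdio>

// Runtime analogue of static_ford<SliceLengths>{}: visit every multi-index
// idx with 0 <= idx[d] < lengths[d]. In the CK code this unrolls at
// compile time and each idx feeds make_tensor_coordinate().
template <std::size_t NDim, typename F>
void ford(const std::array<int, NDim>& lengths, F&& visit)
{
    std::array<int, NDim> idx{};
    while(true)
    {
        visit(idx);
        // odometer-style increment, last dimension fastest
        std::size_t d = NDim;
        while(d > 0)
        {
            --d;
            if(++idx[d] < lengths[d]) break;
            idx[d] = 0;
            if(d == 0) return;
        }
    }
}

int main()
{
    ford<2>({2, 3}, [](const std::array<int, 2>& idx) {
        std::printf("(%d, %d)\n", idx[0], idx[1]);
    });
}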
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP #ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP #define CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
...@@ -57,20 +57,20 @@ template <typename SrcData, ...@@ -57,20 +57,20 @@ template <typename SrcData,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
index_t DstScalarStrideInVector, index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun, bool DstResetCoordinateAfterRun,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v1r3 struct ThreadwiseTensorSliceTransfer_v1r3
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3( __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
const DstDesc& dst_desc, const Index& dst_slice_origin_idx) const Index& dst_slice_origin_idx)
: dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx)) : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx))
{ {
static_assert(SrcDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
...@@ -78,19 +78,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -78,19 +78,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{ {
dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
} }
template <typename SrcSliceOriginIdx, template <typename SrcSliceOriginIdx,
typename SrcBuffer, typename SrcBuffer,
typename DstBuffer, typename DstBuffer,
typename DstIteratorHacks> typename DstStepHacks>
__device__ void Run(const SrcDesc&, __device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&, const SrcSliceOriginIdx&,
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
const DstDesc& dst_desc, const DstDesc& dst_desc,
DstBuffer& dst_buf, DstBuffer& dst_buf,
const DstIteratorHacks& dst_iterator_hacks) const DstStepHacks& dst_step_hacks)
{ {
static_assert(SrcDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
...@@ -127,31 +127,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -127,31 +127,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
constexpr auto ordered_access_lengths = constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order); container_reorder_given_new2old(access_lengths, dim_access_order);
// make forward iterators // make forward steps
const auto dst_forward_iterators = generate_tuple( const auto dst_forward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index forward_step; Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
}); });
return make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
dst_desc, forward_step, dst_iterator_hacks[I0][i]); dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
}, },
Number<nDim>{}); Number<nDim>{});
// make backward iterators // make backward steps
const auto dst_backward_iterators = generate_tuple( const auto dst_backward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index backward_step; Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
}); });
return make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
dst_desc, backward_step, dst_iterator_hacks[I1][i]); dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
}, },
Number<nDim>{}); Number<nDim>{});
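The two generate_tuple calls above build one forward and one backward step per dimension; each step index is zero everywhere except in its own dimension, where it moves by plus or minus dst_scalar_per_access[i] before being wrapped by make_tensor_coordinate_step. The index construction alone, as a standalone sketch:

#include <array>

// Build the per-dimension step indices used above: forward_step[i] moves
// +scalar_per_access[i] along dimension i only; backward steps are the
// negation. (Sketch of the index math only; in CK each index is then
// wrapped by make_tensor_coordinate_step.)
template <std::size_t NDim>
std::array<std::array<int, NDim>, NDim>
make_forward_steps(const std::array<int, NDim>& scalar_per_access)
{
    std::array<std::array<int, NDim>, NDim> steps{};
    for(std::size_t i = 0; i < NDim; ++i)
        for(std::size_t j = 0; j < NDim; ++j)
            steps[i][j] = (i == j) ? scalar_per_access[i] : 0;
    return steps;
}

int main()
{
    auto fwd = make_forward_steps<3>({1, 4, 1});
    // fwd == {{1,0,0}, {0,4,0}, {0,0,1}}; backward steps negate each entry.
    (void)fwd;
}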
...@@ -235,13 +235,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -235,13 +235,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
{ {
if constexpr(forward_sweep[i]) if constexpr(forward_sweep[i])
{ {
move_dynamic_tensor_coordinate( move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_iterators[dim_access_order[i]]); dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
} }
else else
{ {
move_dynamic_tensor_coordinate( move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_iterators[dim_access_order[i]]); dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
} }
} }
}); });
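The sweep logic above moves exactly one dimension of the coordinate between consecutive accesses and alternates the direction of each dimension (forward_sweep[i]), so the traversal snakes through the access space instead of jumping back to a row start. A 2-D, runtime sketch of that ordering:

#include <cstdio>

// Snake (boustrophedon) traversal of a 2-D access grid: the inner
// dimension alternates direction on every outer step, so each move
// changes exactly one coordinate by +/-1. This mirrors the
// forward_sweep[] / forward-or-backward-step choice above, in a
// 2-D, runtime form.
int main()
{
    const int L0 = 3, L1 = 4; // access lengths
    int coord[2] = {0, 0};
    for(int i0 = 0; i0 < L0; ++i0)
    {
        const bool forward = (i0 % 2 == 0); // sweep direction of dim 1
        for(int i1 = 0; i1 < L1; ++i1)
        {
            std::printf("access (%d, %d)\n", coord[0], coord[1]);
            // move along dim 1, except after the last access of the row
            if(i1 + 1 < L1) coord[1] += forward ? 1 : -1;
        }
        // move along dim 0 to the next row, keeping dim 1 where it is
        if(i0 + 1 < L0) coord[0] += 1;
    }
}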
...@@ -250,10 +250,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -250,10 +250,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
// move dst coordinate back to slice origin (or not) // move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun) if constexpr(DstResetCoordinateAfterRun)
{ {
const auto dst_reset_iterator = const auto dst_reset_step =
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
} }
} }
...@@ -268,11 +268,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -268,11 +268,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{}; constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
constexpr auto dst_iterator_hacks = constexpr auto dst_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}), make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks); Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks);
} }
__device__ static constexpr auto GetDstCoordinateResetStep() __device__ static constexpr auto GetDstCoordinateResetStep()
...@@ -345,10 +345,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -345,10 +345,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
: dst_slice_origin_step_idx + GetDstCoordinateResetStep(); : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time? // is it OK to construct a new step every time?
const auto adjusted_step = const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
} }
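MoveDstSliceWindow above folds two displacements into one: the requested window step and, when the coordinate was not reset at the end of the previous Run, the step back to the slice origin, so only a single move_tensor_coordinate is issued. The index arithmetic, in a deliberately simplified form (names here are illustrative):

#include <array>

// Sketch of the "adjusted step" fusion above: if the running coordinate
// was left where the previous pass ended (not reset), the window move and
// the move back to the slice origin are combined into a single step.
template <std::size_t N>
std::array<int, N> add(std::array<int, N> a, const std::array<int, N>& b)
{
    for(std::size_t i = 0; i < N; ++i) a[i] += b[i];
    return a;
}

template <std::size_t N>
std::array<int, N> adjusted_step(const std::array<int, N>& window_step,
                                 const std::array<int, N>& reset_step, // analogue of GetDstCoordinateResetStep()
                                 bool reset_after_run)
{
    return reset_after_run ? window_step : add(window_step, reset_step);
}

int main()
{
    std::array<int, 2> window{0, 8}, reset{-3, -7};
    auto s = adjusted_step(window, reset, /*reset_after_run=*/false); // == {-3, 1}
    (void)s;
}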
private: private:
...@@ -374,20 +373,20 @@ template <typename SrcData, ...@@ -374,20 +373,20 @@ template <typename SrcData,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector, index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun, bool SrcResetCoordinateAfterRun,
typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v2 struct ThreadwiseTensorSliceTransfer_v2
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc, __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
const Index& src_slice_origin_idx) const Index& src_slice_origin_idx)
: src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx)) : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
{ {
static_assert(DstDesc::IsKnownAtCompileTime(), static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
...@@ -395,19 +394,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -395,19 +394,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
__device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{ {
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
} }
template <typename SrcBuffer, template <typename SrcBuffer,
typename DstBuffer, typename DstBuffer,
typename DstSliceOriginIdx, typename DstSliceOriginIdx,
typename SrcIteratorHacks> typename SrcStepHacks>
__device__ void Run(const SrcDesc& src_desc, __device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
const DstDesc&, const DstDesc&,
const DstSliceOriginIdx&, const DstSliceOriginIdx&,
DstBuffer& dst_buf, DstBuffer& dst_buf,
const SrcIteratorHacks& src_iterator_hacks) const SrcStepHacks& src_step_hacks)
{ {
static_assert(DstDesc::IsKnownAtCompileTime(), static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! DstDesc need to known at compile-time"); "wrong! DstDesc need to known at compile-time");
...@@ -442,31 +441,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -442,31 +441,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
constexpr auto ordered_access_lengths = constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order); container_reorder_given_new2old(access_lengths, dim_access_order);
// make forward iterators // make forward steps
const auto src_forward_iterators = generate_tuple( const auto src_forward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index forward_step; Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
}); });
return make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
src_desc, forward_step, src_iterator_hacks[I0][i]); src_desc, forward_step_idx, src_step_hacks[I0][i]);
}, },
Number<nDim>{}); Number<nDim>{});
// make backward iterators // make backward steps
const auto src_backward_iterators = generate_tuple( const auto src_backward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index backward_step; Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
}); });
return make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
src_desc, backward_step, src_iterator_hacks[I1][i]); src_desc, backward_step_idx, src_step_hacks[I1][i]);
}, },
Number<nDim>{}); Number<nDim>{});
...@@ -548,13 +547,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -548,13 +547,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
{ {
if constexpr(forward_sweep[i]) if constexpr(forward_sweep[i])
{ {
move_dynamic_tensor_coordinate( move_tensor_coordinate(
src_desc, src_coord_, src_forward_iterators[dim_access_order[i]]); src_desc, src_coord_, src_forward_steps[dim_access_order[i]]);
} }
else else
{ {
move_dynamic_tensor_coordinate( move_tensor_coordinate(
src_desc, src_coord_, src_backward_iterators[dim_access_order[i]]); src_desc, src_coord_, src_backward_steps[dim_access_order[i]]);
} }
} }
}); });
...@@ -563,10 +562,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -563,10 +562,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
// move src coordinate back to slice origin (or not) // move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun) if constexpr(SrcResetCoordinateAfterRun)
{ {
const auto src_reset_iterator = const auto src_reset_step =
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
} }
} }
...@@ -581,11 +580,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -581,11 +580,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{}; constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
constexpr auto src_iterator_hacks = constexpr auto src_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}), make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks); Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks);
} }
__device__ static constexpr auto GetSrcCoordinateResetStep() __device__ static constexpr auto GetSrcCoordinateResetStep()
...@@ -658,10 +657,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -658,10 +657,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
: src_slice_origin_step_idx + GetSrcCoordinateResetStep(); : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time? // is it OK to construct a new step every time?
const auto adjusted_step = const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
} }
private: private:
...@@ -693,23 +691,23 @@ template <typename SliceLengths, ...@@ -693,23 +691,23 @@ template <typename SliceLengths,
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to // RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation // save addr computation
struct ThreadwiseDynamicTensorSliceTransfer_v3 struct ThreadwiseTensorSliceTransfer_v3
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3(const SrcDesc& src_desc, __device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc,
const Index& src_slice_origin, const Index& src_slice_origin,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_slice_origin) const Index& dst_slice_origin)
: src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)), : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin)) dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
{ {
// TODO: fix this // TODO: fix this
static_assert(is_same<SrcData, DstData>::value, static_assert(is_same<SrcData, DstData>::value,
...@@ -718,18 +716,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -718,18 +716,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{ {
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
} }
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{ {
dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
} }
template <typename SrcBuffer, typename SrcIteratorHacks> template <typename SrcBuffer, typename SrcStepHacks>
__device__ void RunRead(const SrcDesc& src_desc, __device__ void
const SrcBuffer& src_buf, RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
const SrcIteratorHacks& src_iterator_hacks)
{ {
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
...@@ -757,31 +754,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -757,31 +754,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
constexpr auto ordered_src_access_lengths = constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order); container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward iterators // make forward steps
const auto src_forward_iterators = generate_tuple( const auto src_forward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index forward_step; Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
}); });
return make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
src_desc, forward_step, src_iterator_hacks[I0][i]); src_desc, forward_step_idx, src_step_hacks[I0][i]);
}, },
Number<nDim>{}); Number<nDim>{});
// make backward iterators // make backward steps
const auto src_backward_iterators = generate_tuple( const auto src_backward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index backward_step; Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
}); });
return make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
src_desc, backward_step, src_iterator_hacks[I1][i]); src_desc, backward_step_idx, src_step_hacks[I1][i]);
}, },
Number<nDim>{}); Number<nDim>{});
...@@ -862,13 +859,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -862,13 +859,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
{ {
if constexpr(forward_sweep[i]) if constexpr(forward_sweep[i])
{ {
move_dynamic_tensor_coordinate( move_tensor_coordinate(
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]); src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
} }
else else
{ {
move_dynamic_tensor_coordinate( move_tensor_coordinate(
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]); src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
} }
} }
}); });
...@@ -877,17 +874,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -877,17 +874,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// move src coordinate back to slice origin (or not) // move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun) if constexpr(SrcResetCoordinateAfterRun)
{ {
const auto src_reset_iterator = const auto src_reset_step =
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
} }
} }
template <typename DstBuffer, typename DstIteratorHacks> template <typename DstBuffer, typename DstStepHacks>
__device__ void RunWrite(const DstDesc& dst_desc, __device__ void
DstBuffer& dst_buf, RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
const DstIteratorHacks& dst_iterator_hacks)
{ {
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
...@@ -915,35 +911,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -915,35 +911,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
constexpr auto ordered_dst_access_lengths = constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward iterators // make forward steps
const auto dst_forward_iterators = generate_tuple( const auto dst_forward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index forward_step; Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
}); });
const auto forward_iterator = make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
dst_desc, forward_step, dst_iterator_hacks[I0][i]); dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
return forward_iterator;
}, },
Number<nDim>{}); Number<nDim>{});
// make backward iterators // make backward steps
const auto dst_backward_iterators = generate_tuple( const auto dst_backward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index backward_step; Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
}); });
const auto backward_iterator = make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
dst_desc, backward_step, dst_iterator_hacks[I1][i]); dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
return backward_iterator;
}, },
Number<nDim>{}); Number<nDim>{});
...@@ -1026,13 +1018,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1026,13 +1018,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
{ {
if constexpr(forward_sweep[i]) if constexpr(forward_sweep[i])
{ {
move_dynamic_tensor_coordinate( move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]); dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
} }
else else
{ {
move_dynamic_tensor_coordinate( move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]); dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
} }
} }
}); });
...@@ -1041,10 +1033,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1041,10 +1033,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// move dst coordinate back to slice origin (or not) // move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun) if constexpr(DstResetCoordinateAfterRun)
{ {
const auto dst_reset_iterator = const auto dst_reset_step =
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
} }
} }
...@@ -1055,11 +1047,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1055,11 +1047,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{}; constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
constexpr auto src_iterator_hacks = constexpr auto src_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}), make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunRead(src_desc, src_buf, src_iterator_hacks); RunRead(src_desc, src_buf, src_step_hacks);
} }
template <typename DstBuffer> template <typename DstBuffer>
...@@ -1069,11 +1061,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1069,11 +1061,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{}; constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
constexpr auto dst_iterator_hacks = constexpr auto dst_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}), make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunWrite(dst_desc, dst_buf, dst_iterator_hacks); RunWrite(dst_desc, dst_buf, dst_step_hacks);
} }
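The RunRead/RunWrite overloads above that take no hack argument synthesize a neutral one: for each of the nDim forward and nDim backward steps, an all-zero sequence with one entry per transform, meaning "no special handling". A runtime stand-in for that default construction:

#include <array>
#include <vector>

// Sketch of the default "step hacks" built above: for each of the nDim
// forward steps and nDim backward steps, an all-zero hint of length
// ntransform. (CK builds this at compile time with uniform_sequence_gen
// and generate_tuple; this is a runtime stand-in.)
std::array<std::vector<std::vector<int>>, 2>
make_default_step_hacks(std::size_t ndim, std::size_t ntransform)
{
    std::vector<std::vector<int>> zeros(ndim, std::vector<int>(ntransform, 0));
    return {zeros, zeros}; // [0]: forward hints, [1]: backward hints
}

int main()
{
    auto hacks = make_default_step_hacks(/*ndim=*/4, /*ntransform=*/6);
    // hacks[0][i] and hacks[1][i] are each 6 zeros, for i = 0..3
    (void)hacks;
}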
__device__ static constexpr auto GetSrcCoordinateResetStep() __device__ static constexpr auto GetSrcCoordinateResetStep()
...@@ -1206,18 +1198,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1206,18 +1198,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
: src_slice_origin_step_idx + GetSrcCoordinateResetStep(); : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time? // is it OK to construct a new step every time?
const auto adjusted_step = const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
} }
// src_slice_origin_step_idx need to be known at compile-time, for performance reason // src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <typename SrcMoveSliceWindowIteratorHack> template <typename SrcMoveSliceWindowStepHack>
__device__ void __device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc, MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx, const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{ {
// if src coord was not reset by RunRead(), then need to adjust the step here // if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = const auto adjusted_step_idx =
...@@ -1225,10 +1216,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1225,10 +1216,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
: src_slice_origin_step_idx + GetSrcCoordinateResetStep(); : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time? // is it OK to construct a new step every time?
const auto adjusted_step = make_dynamic_tensor_coordinate_iterator( const auto adjusted_step = make_tensor_coordinate_step(
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack); src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
} }
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
...@@ -1240,19 +1231,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1240,19 +1231,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
: dst_slice_origin_step_idx + GetDstCoordinateResetStep(); : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time? // is it OK to construct a new step every time?
const auto adjusted_step = const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
} }
private: private:
static constexpr auto buffer_desc_ = static constexpr auto buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{})); make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{}));
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_; StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
SrcCoord src_coord_; SrcCoord src_coord_;
DstCoord dst_coord_; DstCoord dst_coord_;
...@@ -1264,37 +1254,36 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1264,37 +1254,36 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// 2. SrcBuffer is DynamicBuffer // 2. SrcBuffer is DynamicBuffer
// 3. src_ref_idx is known at run-time // 3. src_ref_idx is known at run-time
// 4. SrcRefToOriginDisplacement is known at compile-time // 4. SrcRefToOriginDisplacement is known at compile-time
// 5. use #-iterator // 5. use #-step
// 2. dst: // 2. dst:
// 1. DstDesc is known at compile-time // 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer // 2. DstBuffer is StaticBuffer
// 3. DstOriginIdx is known at compile-time // 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation // 4. use direct address calculation
// 3. vector access on src // 3. vector access on src
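The comment block above states the v4 contract: a compile-time source descriptor, a source buffer at a run-time address, a destination that is a statically indexed register buffer, and source reads performed a small vector at a time. A heavily simplified sketch of that data movement (a single stride stands in for the full descriptor; names are illustrative):

#include <array>
#include <cstddef>

// Simplified picture of the v4 transfer above: src is a run-time buffer
// addressed through a compile-time-known stride (the "descriptor"),
// dst is a statically indexed register-like array, and src is read
// ScalarPerVector elements at a time. Illustration only.
template <int SrcStride, int ScalarPerVector, std::size_t DstSize>
void transfer_v4_sketch(const float* src_base,  // run-time src address
                        int src_ref_offset,     // run-time reference offset
                        std::array<float, DstSize>& dst)
{
    static_assert(DstSize % ScalarPerVector == 0, "whole vectors only");
    for(std::size_t v = 0; v < DstSize / ScalarPerVector; ++v)
    {
        // displacement from the reference point is known per access
        const int disp = static_cast<int>(v) * ScalarPerVector * SrcStride;
        for(int s = 0; s < ScalarPerVector; ++s)
            dst[v * ScalarPerVector + s] = src_base[src_ref_offset + disp + s * SrcStride];
    }
}

int main()
{
    float src[64];
    for(int i = 0; i < 64; ++i) src[i] = static_cast<float>(i);
    std::array<float, 8> regs{};
    transfer_v4_sketch<1, 4>(src, 16, regs); // copies src[16..23] into regs
}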
template < template <typename SrcData,
typename SrcData, typename DstData,
typename DstData, typename SrcDesc,
typename SrcDesc, typename DstDesc,
typename DstDesc, typename SliceLengths,
typename SliceLengths, typename DimAccessOrder,
typename DimAccessOrder, index_t SrcVectorDim,
index_t SrcVectorDim, index_t SrcScalarPerVector,
index_t SrcScalarPerVector, index_t SrcScalarStrideInVector,
index_t SrcScalarStrideInVector, typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), bool>::type = false>
bool>::type = false> struct ThreadwiseTensorSliceTransfer_v4
struct ThreadwiseDynamicTensorSliceTransfer_v4
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4(const Index& src_ref_idx) __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
: src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx)) : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
{ {
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time"); "wrong! SrcDesc and DstDesc need to known at compile-time");
...@@ -1390,13 +1379,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4 ...@@ -1390,13 +1379,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
constexpr auto src_ref_to_data_disp_idx = constexpr auto src_ref_to_data_disp_idx =
src_ref_to_origin_disp_idx + data_to_origin_disp_idx; src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
constexpr auto src_ref_to_data_disp_coord_iterator = constexpr auto src_ref_to_data_disp_coord_step =
make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx); make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
auto src_data_coord = src_ref_coord_; auto src_data_coord = src_ref_coord_;
move_dynamic_tensor_coordinate( move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector; vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
...@@ -1435,10 +1423,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4 ...@@ -1435,10 +1423,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
{ {
constexpr auto src_desc = SrcDesc{}; constexpr auto src_desc = SrcDesc{};
const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator( const auto src_slice_move_step_iter =
src_desc, to_multi_index(src_slice_move_step_idx)); make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
} }
private: private:
......
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP #ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP #define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
...@@ -30,7 +30,7 @@ template <typename SliceLengths, ...@@ -30,7 +30,7 @@ template <typename SliceLengths,
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to // RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation // save addr computation
struct ThreadwiseDynamicTensorSliceTransfer_v3r1 struct ThreadwiseTensorSliceTransfer_v3r1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -38,18 +38,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -38,18 +38,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3r1(const SrcDesc& src_desc, __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
const Index& src_slice_origin, const Index& src_slice_origin,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_slice_origin) const Index& dst_slice_origin)
: src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)), : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin)) dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
{ {
// TODO: fix this // TODO: fix this
static_assert(is_same<SrcData, DstData>::value, static_assert(is_same<SrcData, DstData>::value,
...@@ -64,18 +64,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -64,18 +64,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{ {
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
} }
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{ {
dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
} }
template <typename SrcBuffer, typename SrcIteratorHacks> template <typename SrcBuffer, typename SrcStepHacks>
__device__ void RunRead(const SrcDesc& src_desc, __device__ void
const SrcBuffer& src_buf, RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
const SrcIteratorHacks& src_iterator_hacks)
{ {
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
...@@ -96,9 +95,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -96,9 +95,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
I1), I1),
SrcVectorTensorContiguousDimOrder{}); SrcVectorTensorContiguousDimOrder{});
constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2( constexpr auto src_vector_desc =
sequence_to_tuple_of_number(src_vector_tensor_lengths), make_naive_tensor_descriptor_v2(sequence_to_tuple_of_number(src_vector_tensor_lengths),
sequence_to_tuple_of_number(src_vector_tensor_strides)); sequence_to_tuple_of_number(src_vector_tensor_strides));
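make_naive_tensor_descriptor_v2 above pairs the vector-tensor lengths with explicit strides; such a naive descriptor maps a multi-index to a linear offset as a dot product with the strides, and the packed variant used for buffer_desc_ is the special case where the strides are the suffix products of the lengths. A small sketch of both:

#include <array>
#include <cstddef>

// Offset calculation of a naive (lengths + strides) descriptor, as used
// for src_vector_desc above: offset(idx) = sum_i idx[i] * stride[i].
// Sketch only; CK evaluates this through its transform machinery.
template <std::size_t NDim>
constexpr std::size_t naive_offset(const std::array<std::size_t, NDim>& idx,
                                   const std::array<std::size_t, NDim>& strides)
{
    std::size_t off = 0;
    for(std::size_t i = 0; i < NDim; ++i) off += idx[i] * strides[i];
    return off;
}

// Packed case: strides are the suffix products of the lengths.
template <std::size_t NDim>
constexpr std::array<std::size_t, NDim>
packed_strides(const std::array<std::size_t, NDim>& lengths)
{
    std::array<std::size_t, NDim> s{};
    std::size_t run = 1;
    for(std::size_t i = NDim; i-- > 0;)
    {
        s[i] = run;
        run *= lengths[i];
    }
    return s;
}

int main()
{
    constexpr std::array<std::size_t, 3> lengths{2, 3, 4};
    constexpr auto strides = packed_strides(lengths);         // {12, 4, 1}
    static_assert(naive_offset<3>({1, 2, 3}, strides) == 23); // 1*12 + 2*4 + 3*1
}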
// access order and lengths // access order and lengths
constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths; constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;
...@@ -108,31 +107,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -108,31 +107,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
constexpr auto ordered_src_access_lengths = constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order); container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward iterators // make forward steps
const auto src_forward_iterators = generate_tuple( const auto src_forward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index forward_step; Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0; forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
}); });
return make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
src_desc, forward_step, src_iterator_hacks[I0][i]); src_desc, forward_step_idx, src_step_hacks[I0][i]);
}, },
Number<nDim>{}); Number<nDim>{});
// make backward iterators // make backward steps
const auto src_backward_iterators = generate_tuple( const auto src_backward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index backward_step; Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0; backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
}); });
return make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
src_desc, backward_step, src_iterator_hacks[I1][i]); src_desc, backward_step_idx, src_step_hacks[I1][i]);
}, },
Number<nDim>{}); Number<nDim>{});
...@@ -219,13 +218,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -219,13 +218,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
{ {
if constexpr(forward_sweep[i]) if constexpr(forward_sweep[i])
{ {
move_dynamic_tensor_coordinate( move_tensor_coordinate(
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]); src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
} }
else else
{ {
move_dynamic_tensor_coordinate( move_tensor_coordinate(
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]); src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
} }
} }
}); });
...@@ -234,17 +233,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -234,17 +233,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
// move src coordinate back to slice origin (or not) // move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun) if constexpr(SrcResetCoordinateAfterRun)
{ {
const auto src_reset_iterator = const auto src_reset_step =
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
} }
} }
template <typename DstBuffer, typename DstIteratorHacks> template <typename DstBuffer, typename DstStepHacks>
__device__ void RunWrite(const DstDesc& dst_desc, __device__ void
DstBuffer& dst_buf, RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
const DstIteratorHacks& dst_iterator_hacks)
{ {
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
...@@ -265,9 +263,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -265,9 +263,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
I1), I1),
DstVectorTensorContiguousDimOrder{}); DstVectorTensorContiguousDimOrder{});
constexpr auto dst_vector_desc = make_dynamic_naive_tensor_descriptor_v2( constexpr auto dst_vector_desc =
sequence_to_tuple_of_number(dst_vector_tensor_lengths), make_naive_tensor_descriptor_v2(sequence_to_tuple_of_number(dst_vector_tensor_lengths),
sequence_to_tuple_of_number(dst_vector_tensor_strides)); sequence_to_tuple_of_number(dst_vector_tensor_strides));
// dst access order and lengths // dst access order and lengths
constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths; constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;
...@@ -277,35 +275,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -277,35 +275,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
constexpr auto ordered_dst_access_lengths = constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward iterators // make forward steps
const auto dst_forward_iterators = generate_tuple( const auto dst_forward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index forward_step; Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0; forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
}); });
const auto forward_iterator = make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
dst_desc, forward_step, dst_iterator_hacks[I0][i]); dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
return forward_iterator;
}, },
Number<nDim>{}); Number<nDim>{});
// make backward iterators // make backward steps
const auto dst_backward_iterators = generate_tuple( const auto dst_backward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
Index backward_step; Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) { static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0; backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
}); });
const auto backward_iterator = make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_step(
dst_desc, backward_step, dst_iterator_hacks[I1][i]); dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
return backward_iterator;
}, },
Number<nDim>{}); Number<nDim>{});
...@@ -394,13 +388,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -394,13 +388,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
{ {
if constexpr(forward_sweep[i]) if constexpr(forward_sweep[i])
{ {
move_dynamic_tensor_coordinate( move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]); dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
} }
else else
{ {
move_dynamic_tensor_coordinate( move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]); dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
} }
} }
}); });
...@@ -409,10 +403,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -409,10 +403,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
// move dst coordinate back to slice origin (or not) // move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun) if constexpr(DstResetCoordinateAfterRun)
{ {
const auto dst_reset_iterator = const auto dst_reset_step =
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
} }
} }
...@@ -423,11 +417,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -423,11 +417,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{}; constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
constexpr auto src_iterator_hacks = constexpr auto src_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}), make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunRead(src_desc, src_buf, src_iterator_hacks); RunRead(src_desc, src_buf, src_step_hacks);
} }
template <typename DstBuffer> template <typename DstBuffer>
...@@ -437,11 +431,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -437,11 +431,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{}; constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
constexpr auto dst_iterator_hacks = constexpr auto dst_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}), make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunWrite(dst_desc, dst_buf, dst_iterator_hacks); RunWrite(dst_desc, dst_buf, dst_step_hacks);
} }
__device__ static constexpr auto GetSrcCoordinateResetStep() __device__ static constexpr auto GetSrcCoordinateResetStep()
...@@ -564,18 +558,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -564,18 +558,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
: src_slice_origin_step_idx + GetSrcCoordinateResetStep(); : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time? // is it OK to construct a new step every time?
const auto adjusted_step = const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
} }
// src_slice_origin_step_idx need to be known at compile-time, for performance reason // src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <typename SrcMoveSliceWindowIteratorHack> template <typename SrcMoveSliceWindowStepHack>
__device__ void __device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc, MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx, const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{ {
// if src coord was not reset by RunRead(), then need to adjust the step here // if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = const auto adjusted_step_idx =
...@@ -583,10 +576,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -583,10 +576,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
: src_slice_origin_step_idx + GetSrcCoordinateResetStep(); : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time? // is it OK to construct a new step every time?
const auto adjusted_step = make_dynamic_tensor_coordinate_iterator( const auto adjusted_step = make_tensor_coordinate_step(
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack); src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
} }
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
...@@ -598,19 +591,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -598,19 +591,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
: dst_slice_origin_step_idx + GetDstCoordinateResetStep(); : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time? // is it OK to construct a new step every time?
const auto adjusted_step = const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
} }
private: private:
static constexpr auto buffer_desc_ = static constexpr auto buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{})); make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{}));
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_; StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
SrcCoord src_coord_; SrcCoord src_coord_;
DstCoord dst_coord_; DstCoord dst_coord_;
...@@ -622,25 +614,24 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -622,25 +614,24 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
// 2. SrcBuffer is DynamicBuffer // 2. SrcBuffer is DynamicBuffer
// 3. src_ref_idx is known at run-time // 3. src_ref_idx is known at run-time
// 4. SrcRefToOriginDisplacement is known at compile-time // 4. SrcRefToOriginDisplacement is known at compile-time
// 5. use #-iterator // 5. use #-step
// 2. dst: // 2. dst:
// 1. DstDesc is known at compile-time // 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer // 2. DstBuffer is StaticBuffer
// 3. DstOriginIdx is known at compile-time // 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation // 4. use direct address calculation
// 3. vector access on src // 3. vector access on src
template < template <typename SrcData,
typename SrcData, typename DstData,
typename DstData, typename SrcDesc,
typename SrcDesc, typename DstDesc,
typename DstDesc, typename SliceLengths,
typename SliceLengths, typename DimAccessOrder,
typename DimAccessOrder, typename SrcVectorTensorLengths,
typename SrcVectorTensorLengths, typename SrcVectorTensorContiguousDimOrder,
typename SrcVectorTensorContiguousDimOrder, typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), bool>::type = false>
bool>::type = false> struct ThreadwiseTensorSliceTransfer_v4r1
struct ThreadwiseDynamicTensorSliceTransfer_v4r1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -649,12 +640,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1 ...@@ -649,12 +640,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4r1(const Index& src_ref_idx) __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
: src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx)) : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
{ {
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time"); "wrong! SrcDesc and DstDesc need to known at compile-time");
...@@ -712,9 +703,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1 ...@@ -712,9 +703,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
I1), I1),
SrcVectorTensorContiguousDimOrder{}); SrcVectorTensorContiguousDimOrder{});
constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2( constexpr auto src_vector_desc =
sequence_to_tuple_of_number(src_vector_tensor_lengths), make_naive_tensor_descriptor_v2(sequence_to_tuple_of_number(src_vector_tensor_lengths),
sequence_to_tuple_of_number(src_vector_tensor_strides)); sequence_to_tuple_of_number(src_vector_tensor_strides));
// access order and lengths // access order and lengths
constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths;
...@@ -734,13 +725,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1 ...@@ -734,13 +725,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
constexpr auto src_ref_to_data_disp_idx = constexpr auto src_ref_to_data_disp_idx =
src_ref_to_origin_disp_idx + data_to_origin_disp_idx; src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
constexpr auto src_ref_to_data_disp_coord_iterator = constexpr auto src_ref_to_data_disp_coord_step =
make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx); make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
auto src_data_coord = src_ref_coord_; auto src_data_coord = src_ref_coord_;
move_dynamic_tensor_coordinate( move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector; vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
...@@ -775,10 +765,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1 ...@@ -775,10 +765,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
{ {
constexpr auto src_desc = SrcDesc{}; constexpr auto src_desc = SrcDesc{};
const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator( const auto src_slice_move_step_iter =
src_desc, to_multi_index(src_slice_move_step_idx)); make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
} }
private: private:
......
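Note: the renamed coordinate API in the hunks above keeps the old calling pattern: build a coordinate from a descriptor plus a start index, build a step from the descriptor plus an index delta, then move the coordinate by that step. Below is a minimal device-side sketch of that pattern, not part of this commit; it assumes the CK headers shown above are included, that the code sits in (or imports from) namespace ck, that make_multi_index is available from the common headers, and that the descriptor and index values are invented for illustration.

// hypothetical usage sketch, not taken from this diff
template <typename Desc>
__device__ void walk_one_row(const Desc& desc)
{
    // coordinate at element (0, 0) of an assumed 2D descriptor
    auto coord = make_tensor_coordinate(desc, make_multi_index(0, 0));

    // precomputed step of +1 along dimension 0
    const auto step = make_tensor_coordinate_step(desc, make_multi_index(1, 0));

    // advance the coordinate by the precomputed step
    move_tensor_coordinate(desc, coord, step);
}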
...@@ -350,8 +350,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16> ...@@ -350,8 +350,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{ {
const auto p_a = reinterpret_cast<const ushort2_t*>(a); const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b); const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run( return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
p_a, p_b, reg_c); p_a, p_b, reg_c);
...@@ -384,8 +384,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16> ...@@ -384,8 +384,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{ {
const auto p_a = reinterpret_cast<const ushort2_t*>(a); const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b); const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c); return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
} }
...@@ -417,8 +417,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16> ...@@ -417,8 +417,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{ {
const auto p_a = reinterpret_cast<const ushort2_t*>(a); const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b); const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c); return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
} }
...@@ -450,8 +450,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16> ...@@ -450,8 +450,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{ {
const auto p_a = reinterpret_cast<const ushort2_t*>(a); const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b); const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c); return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
} }
...@@ -483,8 +483,8 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16> ...@@ -483,8 +483,8 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{ {
const auto p_a = reinterpret_cast<const ushort2_t*>(a); const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b); const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c); return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
} }
......
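Note: the mfma hunks above replace reinterpret_cast with c_style_pointer_cast when reinterpreting the packed bf16 operands. The helper's definition is not part of this diff; a plausible minimal form, assuming it simply wraps a C-style cast (which, unlike reinterpret_cast, can also change address-space qualifiers under HIP/Clang), would be:

// hypothetical minimal definition, not taken from this commit
template <typename PY, typename PX>
__host__ __device__ PY c_style_pointer_cast(PX p_x)
{
    // intentional C-style cast: reinterprets the pointee type and,
    // if needed, the pointer's address space
    return (PY)p_x;
}

Used as in the hunks above, e.g. const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);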