Unverified Commit 31b40352 authored by Chao Liu, committed by GitHub

Merge pull request #16 from ROCmSoftwarePlatform/develop

Merge develop into master
parents 5781adf5 b62bf8c3
-#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP
+#ifndef CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
-#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP
+#define CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
#include "common_header.hpp"
-#include "dynamic_multi_index_transform.hpp"
+#include "multi_index_transform.hpp"
namespace ck {
template <typename LowLength>
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
{
-    return DynamicPassThrough<LowLength>{low_length};
+    return PassThrough<LowLength>{low_length};
}
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
@@ -19,47 +19,46 @@ make_pad_transform(const LowLength& low_length,
                   const RightPad& right_pad,
                   integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
-    return DynamicPad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{
-        low_length, left_pad, right_pad};
+    return Pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
}
-template <typename LowLength, typename LeftPad, bool SkipIsValidCheck = false>
+template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto make_left_pad_transform(
    const LowLength& low_length,
-    const LeftPad& left_pad,
+    const LeftPadLength& left_pad,
    integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
-    return DynamicLeftPad<LowLength, LeftPad, SkipIsValidCheck>{low_length, left_pad};
+    return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
}
-template <typename LowLength, typename RightPad, bool SkipIsValidCheck>
+template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
__host__ __device__ constexpr auto make_right_pad_transform(
    const LowLength& low_length,
-    const RightPad& right_pad,
+    const RightPadLength& right_pad,
    integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
-    return DynamicRightPad<LowLength, RightPad, SkipIsValidCheck>{low_length, right_pad};
+    return RightPad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad};
}
template <typename UpLengths,
          typename Coefficients,
-          typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
+          typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
                                                        const Coefficients& coefficients)
{
-    return DynamicEmbed<UpLengths, Coefficients>{up_lengths, coefficients};
+    return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
}
template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
-    return DynamicMerge_v1_carry_check<LowLengths>{low_lengths};
+    return Merge_v1_carry_check<LowLengths>{low_lengths};
#else
#if 1
-    return DynamicMerge_v2_magic_division<LowLengths>{low_lengths};
+    return Merge_v2_magic_division<LowLengths>{low_lengths};
#else
-    return DynamicMerge_v2r2_magic_division<LowLengths>{low_lengths};
+    return Merge_v2r2_magic_division<LowLengths>{low_lengths};
#endif
#endif
}
@@ -68,7 +67,7 @@ template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{
-    return DynamicMerge_v2_magic_division<LowLengths>{low_lengths};
+    return Merge_v2_magic_division<LowLengths>{low_lengths};
}
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
@@ -76,13 +75,13 @@ __host__ __device__ constexpr auto make_unmerge_transform(
    const UpLengths& up_lengths,
    integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{})
{
-    return DynamicUnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
+    return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}
template <typename LowerIndex>
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{
-    return DynamicFreeze<LowerIndex>{low_idx};
+    return Freeze<LowerIndex>{low_idx};
}
template <typename LowLength, typename SliceBegin, typename SliceEnd>
@@ -90,14 +89,14 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len
                                                        const SliceBegin& slice_begin,
                                                        const SliceEnd& slice_end)
{
-    return DynamicSlice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
+    return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
}
template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
                                                            const UpLength& up_length)
{
-    return DynamicVectorize<VectorSize, UpLength>{vector_size, up_length};
+    return Vectorize<VectorSize, UpLength>{vector_size, up_length};
}
} // namespace ck
...
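The helpers above are only renamed by this commit; for readers new to them, the merge/unmerge pair is easiest to picture on a plain dense layout. The sketch below is illustrative only (standard C++, not CK code, and the row-major folding is an assumption of the example): make_merge_transform folds several lower dimensions into one upper index, and make_unmerge_transform splits one index back into several.

// Conceptual sketch only, not part of this commit.
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdio>

// "merge": fold a multi-index into a single linear index (row-major folding).
template <std::size_t N>
int merge_index(const std::array<int, N>& idx, const std::array<int, N>& lengths)
{
    int linear = 0;
    for(std::size_t i = 0; i < N; ++i)
        linear = linear * lengths[i] + idx[i];
    return linear;
}

// "unmerge": split a linear index back into a multi-index.
template <std::size_t N>
std::array<int, N> unmerge_index(int linear, const std::array<int, N>& lengths)
{
    std::array<int, N> idx{};
    for(std::size_t i = N; i-- > 0;)
    {
        idx[i] = linear % lengths[i];
        linear /= lengths[i];
    }
    return idx;
}

int main()
{
    const std::array<int, 3> lengths{4, 8, 16};
    const std::array<int, 3> idx{1, 2, 3};
    const int linear = merge_index(idx, lengths); // 1*8*16 + 2*16 + 3 = 163
    assert(unmerge_index(linear, lengths) == idx);
    std::printf("merged index = %d\n", linear);
    return 0;
}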
@@ -2,8 +2,8 @@
#define CK_TENSOR_ADAPTOR_HPP
#include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
+#include "tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor_helper.hpp"
namespace ck {
@@ -64,7 +64,7 @@ struct TensorAdaptor
        Number<ndim_top_>{});
    // TODO: make container_reduce support tuple of Number and index_t
-    return container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
+    return container_reduce(lengths, math::multiplies{}, Number<1>{});
}
template <index_t IDim>
@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf
    remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
}
-template <typename X,
-          typename... Xs,
-          typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
+template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
    return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
...
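chain_tensor_adaptors above recursively composes adaptors; conceptually this is ordinary function composition of index maps. A minimal illustrative sketch follows (not CK code; the pad and stride maps are made-up examples):

// Conceptual sketch only, not part of this commit.
#include <cstdio>
#include <functional>

using IndexMap = std::function<int(int)>;

// compose two maps so that the second is applied to the output of the first
IndexMap chain(IndexMap first, IndexMap second)
{
    return [=](int i) { return second(first(i)); };
}

int main()
{
    IndexMap pad_left = [](int i) { return i + 2; }; // shift by a left pad of 2
    IndexMap stride_4 = [](int i) { return i * 4; }; // embed with stride 4
    IndexMap chained  = chain(pad_left, stride_4);
    std::printf("%d\n", chained(3)); // (3 + 2) * 4 = 20
    return 0;
}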
-#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP
+#ifndef CK_TENSOR_DESCRIPTOR_HPP
-#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP
+#define CK_TENSOR_DESCRIPTOR_HPP
#include "common_header.hpp"
-#include "dynamic_multi_index_transform.hpp"
+#include "multi_index_transform.hpp"
namespace ck {
template <index_t NDimHidden, typename VisibleDimensionIds>
-struct DynamicTensorCoordinate;
+struct TensorCoordinate;
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
-struct DynamicTensorCoordinateIterator;
+struct TensorCoordinateStep;
// Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
@@ -21,7 +21,7 @@ template <typename Transforms,
          typename UpperDimensionIdss,
          typename VisibleDimensionIds,
          typename ElementSpaceSize>
-struct DynamicTensorDescriptor
+struct TensorDescriptor
{
    // TODO make these private
    __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
@@ -69,7 +69,7 @@ struct DynamicTensorDescriptor
            Number<ndim_visible_>{});
        // TODO: make container_reduce support tuple of Number and index_t
-        return container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
+        return container_reduce(lengths, math::multiplies{}, Number<1>{});
    }
    template <index_t IDim>
@@ -105,15 +105,15 @@ struct DynamicTensorDescriptor
    using VisibleIndex = MultiIndex<ndim_visible_>;
    using HiddenIndex  = MultiIndex<ndim_hidden_>;
-    using Coordinate = DynamicTensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
+    using Coordinate = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
    // may be index_t or Number<>
    using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
    public:
-    __host__ __device__ constexpr DynamicTensorDescriptor() = default;
+    __host__ __device__ constexpr TensorDescriptor() = default;
-    __host__ __device__ constexpr DynamicTensorDescriptor(const Transforms& transforms,
-                                                          ElementSpaceSize element_space_size)
+    __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
+                                                   ElementSpaceSize element_space_size)
        : transforms_{transforms},
          element_size_{InitializeElementSize(transforms)},
@@ -159,7 +159,7 @@ struct DynamicTensorDescriptor
    {
        static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");
-        return make_dynamic_tensor_coordinate(*this, idx).GetOffset();
+        return make_tensor_coordinate(*this, idx).GetOffset();
    }
    // TODO make these private
@@ -196,7 +196,7 @@ struct DynamicTensorDescriptor
    __host__ __device__ void Print() const
    {
        printf("{");
-        printf("DynamicTensorDescriptor, ");
+        printf("TensorDescriptor, ");
        static_for<0, ntransform_, 1>{}([&](auto i) {
            printf("transforms: ");
            transforms_[i].Print();
@@ -217,7 +217,7 @@ struct DynamicTensorDescriptor
};
template <index_t NDimHidden, typename VisibleDimensionIds>
-struct DynamicTensorCoordinate
+struct TensorCoordinate
{
    // TODO make these private
    static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();
@@ -226,9 +226,9 @@ struct DynamicTensorCoordinate
    using VisibleIndex = MultiIndex<ndim_visible_>;
    public:
-    __host__ __device__ constexpr DynamicTensorCoordinate() = default;
+    __host__ __device__ constexpr TensorCoordinate() = default;
-    __host__ __device__ constexpr DynamicTensorCoordinate(const HiddenIndex& idx_hidden)
+    __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
        : idx_hidden_{idx_hidden}
    {
    }
@@ -252,16 +252,16 @@ struct DynamicTensorCoordinate
};
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
-struct DynamicTensorCoordinateIterator
+struct TensorCoordinateStep
{
    // TODO make these private
    using VisibleIndex = MultiIndex<NDimVisible>;
    public:
-    __host__ __device__ constexpr DynamicTensorCoordinateIterator() = default;
+    __host__ __device__ constexpr TensorCoordinateStep() = default;
-    __host__ __device__ constexpr DynamicTensorCoordinateIterator(
-        const VisibleIndex& idx_diff_visible, const MultiIndex<NTransform>& do_transforms)
+    __host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
+                                                       const MultiIndex<NTransform>& do_transforms)
        : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
    {
    }
@@ -283,7 +283,7 @@ struct DynamicTensorCoordinateIterator
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor, and to put it outside the scope where it is used
-// (transform_dynamic_tensor_descriptor) because template cannot be defined inside a function
+// (transform_tensor_descriptor) because template cannot be defined inside a function
// template
template <typename NewTransforms>
struct lambda_get_up_dim_num
@@ -301,7 +301,7 @@ template <typename OldTensorDescriptor,
          typename NewLowerDimensionOldVisibleIdss,
          typename NewUpperDimensionNewVisibleIdss>
__host__ __device__ constexpr auto
-transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
+transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
                            const NewTransforms& new_transforms,
                            NewLowerDimensionOldVisibleIdss,
                            NewUpperDimensionNewVisibleIdss)
@@ -376,7 +376,7 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
    const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
-    return DynamicTensorDescriptor<remove_cv_t<decltype(all_transforms)>,
+    return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
                            remove_cv_t<decltype(all_low_dim_hidden_idss)>,
                            remove_cv_t<decltype(all_up_dim_hidden_idss)>,
                            remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
@@ -385,7 +385,7 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
}
template <typename TensorDesc, typename VisibleIndex>
-__host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDesc& tensor_desc,
+__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
                                                          const VisibleIndex& idx_visible)
{
    static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
@@ -416,14 +416,15 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
        set_container_subset(idx_hidden, dims_low, idx_low);
    });
-    return DynamicTensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
+    return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
}
// UpdateLowerIndexHack: Sequence<...>
// HACK: control UpdateLowerIndex
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
-__host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
-    const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack)
+__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
+                                                               const VisibleIndex& idx_diff_visible,
+                                                               UpdateLowerIndexHack)
{
    static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
                  "wrong! # of dimension inconsistent");
@@ -470,23 +471,24 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
        set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
    });
-    return DynamicTensorCoordinateIterator<ntransform, ndim_visible, UpdateLowerIndexHack>{
-        idx_diff_visible, do_transforms};
+    return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
+                                                                                do_transforms};
}
template <typename TensorDesc, typename VisibleIndex>
-__host__ __device__ constexpr auto
-make_dynamic_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible)
+__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
+                                                               const VisibleIndex& idx_diff_visible)
{
    constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
-    return make_dynamic_tensor_coordinate_iterator(
+    return make_tensor_coordinate_step(
        TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
}
-template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterator>
+template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
-__host__ __device__ constexpr void move_dynamic_tensor_coordinate(
-    const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator)
+__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
+                                                          TensorCoord& coord,
+                                                          const TensorCoordStep& coord_step)
{
    constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
    constexpr index_t ntransform  = TensorDesc::GetNumOfTransform();
@@ -495,9 +497,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
    auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
    // initialize visible index diff
-    set_container_subset(idx_diff_hidden,
-                         TensorDesc::GetVisibleDimensionIds(),
-                         coord_iterator.GetVisibleIndexDiff());
+    set_container_subset(
+        idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
    // this is what needs to be updated
    auto& idx_hidden = coord.GetHiddenIndex();
@@ -506,13 +507,13 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
    auto idx_hidden_pick_visible =
        get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
-    idx_hidden_pick_visible += coord_iterator.GetIndexDiff();
+    idx_hidden_pick_visible += coord_step.GetIndexDiff();
    set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
    // update rest of hidden index
    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
-        if(coord_iterator.do_transforms_[itran])
+        if(coord_step.do_transforms_[itran])
        {
            const auto& tran        = tensor_desc.GetTransforms().At(itran);
            constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
@@ -524,8 +525,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
            MultiIndex<dims_low.Size()> idx_diff_low;
-            // HACK: control UpdateLowerIndex for DynamicMerge using hack
+            // HACK: control UpdateLowerIndex for Merge using hack
-            constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran);
+            constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);
            tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
@@ -585,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
}
template <typename TensorDesc>
-using DynamicTensorCoordinate_t = decltype(make_dynamic_tensor_coordinate(
+using TensorCoordinate_t = decltype(make_tensor_coordinate(
    TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
template <typename TensorDesc>
-using DynamicTensorCoordinateIterator_t = decltype(make_dynamic_tensor_coordinate_iterator(
+using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
    TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
} // namespace ck
...
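make_tensor_coordinate, make_tensor_coordinate_step and move_tensor_coordinate (renamed above from their make_dynamic_* / iterator counterparts) exist so that stepping a coordinate reuses precomputed per-step information instead of re-running the whole transform chain. A rough standalone analogue for the simplest possible case, a purely strided 2-D layout, is sketched below (illustrative only, not the CK implementation):

// Conceptual sketch only, not part of this commit.
#include <array>
#include <cstdio>

struct Coord
{
    std::array<int, 2> idx;
    int offset; // linear offset kept in sync with idx
};

struct Step
{
    std::array<int, 2> idx_diff;
    int offset_diff; // precomputed once from idx_diff and the strides
};

Step make_step(const std::array<int, 2>& idx_diff, const std::array<int, 2>& strides)
{
    return {idx_diff, idx_diff[0] * strides[0] + idx_diff[1] * strides[1]};
}

void move_coord(Coord& c, const Step& s)
{
    c.idx[0] += s.idx_diff[0];
    c.idx[1] += s.idx_diff[1];
    c.offset += s.offset_diff; // cheap incremental update, no re-evaluation of the map
}

int main()
{
    const std::array<int, 2> strides{8, 1}; // 2-D row-major, inner length 8
    Coord c{{1, 2}, 1 * 8 + 2};
    const Step s = make_step({0, 4}, strides);
    move_coord(c, s);
    std::printf("offset = %d\n", c.offset); // 1*8 + 6 = 14
    return 0;
}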
-#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP
+#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
-#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP
+#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
#include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
+#include "tensor_descriptor.hpp"
-#include "dynamic_multi_index_transform_helper.hpp"
+#include "multi_index_transform_helper.hpp"
namespace ck {
@@ -37,9 +37,8 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
template <typename... Lengths,
          typename... Strides,
-          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
+          typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
-__host__ __device__ constexpr auto
-make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
-                                        const Tuple<Strides...>& strides)
+__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Lengths...>& lengths,
+                                                                const Tuple<Strides...>& strides)
{
    constexpr index_t N = sizeof...(Lengths);
@@ -75,7 +74,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
        calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
#endif
-    return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
+    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                            remove_cv_t<decltype(low_dim_hidden_idss)>,
                            remove_cv_t<decltype(up_dim_hidden_idss)>,
                            remove_cv_t<decltype(visible_dim_hidden_ids)>,
@@ -88,7 +87,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
// 2) Number<>, which is known at compile-time
template <typename... Lengths>
__host__ __device__ constexpr auto
-make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
+make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
{
    constexpr index_t N = sizeof...(Lengths);
@@ -101,9 +100,9 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
-    const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
+    const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{});
-    return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
+    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                            remove_cv_t<decltype(low_dim_hidden_idss)>,
                            remove_cv_t<decltype(up_dim_hidden_idss)>,
                            remove_cv_t<decltype(visible_dim_hidden_ids)>,
@@ -113,7 +112,7 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
template <typename... Lengths, typename Align>
__host__ __device__ constexpr auto
-make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align)
+make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align align)
{
    constexpr auto I1 = Number<1>{};
@@ -134,7 +133,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths
        else
        {
            return container_reduce(lengths,
-                                    math::multiplies_v2{},
+                                    math::multiplies{},
                                    Number<stride_n_minus_2>{},
                                    i + I1,
                                    Number<N - 1>{},
@@ -143,7 +142,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths
        },
        Number<N>{});
-    return make_dynamic_naive_tensor_descriptor_v2(lengths, strides);
+    return make_naive_tensor_descriptor(lengths, strides);
}
} // namespace ck
...
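make_naive_tensor_descriptor takes explicit strides, while the _packed variant derives them from the lengths alone, with the element space size being the product of the lengths (the container_reduce with math::multiplies above). A standalone sketch of that packed-stride rule follows (illustrative only, not CK code):

// Conceptual sketch only, not part of this commit.
#include <array>
#include <cstddef>
#include <cstdio>

template <std::size_t N>
std::array<int, N> packed_strides(const std::array<int, N>& lengths)
{
    std::array<int, N> strides{};
    int running = 1;
    for(std::size_t i = N; i-- > 0;)
    {
        strides[i] = running; // stride of dim i = product of lengths of the faster dims
        running *= lengths[i];
    }
    return strides;
}

int main()
{
    const std::array<int, 3> lengths{2, 3, 4};
    const auto strides = packed_strides(lengths);
    std::printf("%d %d %d\n", strides[0], strides[1], strides[2]); // 12 4 1
    return 0;
}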
@@ -3,7 +3,7 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace ck {
@@ -22,7 +22,8 @@ namespace ck {
// 2. CThreadBuffer is StaticBuffer
// Also assume:
//   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
-template <index_t BlockSize,
+template <
+    index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
@@ -37,8 +38,7 @@ template <index_t BlockSize,
          index_t M1N1ThreadClusterN101,
          index_t AThreadCopyScalarPerVector_M11,
          index_t BThreadCopyScalarPerVector_N11,
-          typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
-                                      BKNBlockDesc::IsKnownAtCompileTime(),
+          typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{
@@ -71,9 +71,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
    static constexpr index_t N0 = N / N1;
    __host__ __device__ static constexpr auto
-    MakeAKM0M1BlockDescriptor(const AKMBlockDesc& a_k_m_block_desc)
+    MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
    {
-        const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
+        const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
            AKMBlockDesc{},
            make_tuple(make_pass_through_transform(Number<K>{}),
                       make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
@@ -84,9 +84,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
    }
    __host__ __device__ static constexpr auto
-    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc)
+    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
    {
-        const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
+        const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
            BKNBlockDesc{},
            make_tuple(make_pass_through_transform(Number<K>{}),
                       make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
@@ -194,7 +194,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename CThreadBuffer>
-    __device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
+    __device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */,
                        const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
@@ -357,15 +357,14 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
    private:
    // A[K, M0, M1]
-    static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+    static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));
    // B[K, N0, N1]
-    static constexpr auto b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+    static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));
-    using AThreadCopy =
-        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
                                                FloatA,
                                                decltype(a_k_m0_m1_block_desc_),
                                                decltype(a_k_m0_m1_thread_desc_),
@@ -375,8 +374,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
                                                AThreadCopyScalarPerVector_M11,
                                                1>;
-    using BThreadCopy =
-        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
                                                FloatB,
                                                decltype(b_k_n0_n1_block_desc_),
                                                decltype(b_k_n0_n1_thread_desc_),
...
@@ -3,7 +3,7 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
+#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace ck {
@@ -38,7 +38,7 @@ template <index_t BlockSize,
          // BM10BN10ThreadClusterBN101, ...>
          index_t AThreadCopyScalarPerVector_BM11,
          index_t BThreadCopyScalarPerVector_BN11,
-          typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
+          typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
                             BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
                             bool>::type = false>
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
@@ -75,7 +75,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
    __host__ __device__ static constexpr auto
    MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
    {
-        const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor(
+        const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
            a_block_desc_bk0_bm_bk1,
            make_tuple(make_pass_through_transform(Number<BK0>{}),
                       make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
@@ -89,7 +89,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
    __host__ __device__ static constexpr auto
    MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
    {
-        const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor(
+        const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
            b_block_desc_bk0_bn_bk1,
            make_tuple(make_pass_through_transform(Number<BK0>{}),
                       make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
@@ -372,15 +372,15 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
    private:
    // A[BK0, BM0, BM1, BK1]
    static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+        make_naive_tensor_descriptor_packed(make_tuple(
            Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));
    // B[BK0, BN0, BN1, BK1]
    static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+        make_naive_tensor_descriptor_packed(make_tuple(
            Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));
-    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
        FloatA,
        FloatA,
        decltype(a_block_desc_bk0_bm0_bm1_bk1_),
@@ -390,7 +390,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
        Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
        Sequence<0, 1, 2, 3>>;                 // SrcVectorTensorContiguousDimOrder
-    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
        FloatB,
        FloatB,
        decltype(b_block_desc_bk0_bn0_bn1_bk1_),
...
@@ -31,17 +31,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
    // HACK: fix this @Jing Zhang
    static constexpr index_t KPerThreadSubC = 4;
-    static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+    static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
-    static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+    static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
        Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
-    static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+    static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
        Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
-    using AThreadCopy =
-        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
                                                FloatA,
                                                BlockMatrixA,
                                                decltype(a_thread_mtx_),
@@ -69,7 +68,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                      "wrong! K dimension not consistent\n");
        constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
-        constexpr index_t N = BlockMatrixB{}.GetLength(I1);
        constexpr index_t H = BlockMatrixB{}.GetLength(I2);
        constexpr index_t W = BlockMatrixB{}.GetLength(I3);
@@ -121,9 +119,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                      "wrong! inconsistent type");
        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
        constexpr auto a_block_mtx = BlockMatrixA{};
@@ -138,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
        static_assert(WPerThread % WoPerThreadSubC == 0, "");
        // thread A buffer for GEMM
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()>
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
            a_thread_buf;
        constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
...
@@ -2,7 +2,7 @@
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
namespace ck {
@@ -52,7 +52,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
        const index_t waveId = thread_id / WaveSize;
        const index_t laneId = thread_id % WaveSize;
        const index_t waveId_m = waveId / NWaves;
-        const index_t waveId_n = waveId % NWaves;
        if constexpr(xdlops_gemm.IsKReduction)
        {
@@ -73,7 +72,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
        const index_t thread_id = get_thread_local_1d_id();
        const index_t waveId = thread_id / WaveSize;
        const index_t laneId = thread_id % WaveSize;
-        const index_t waveId_m = waveId / NWaves;
        const index_t waveId_n = waveId % NWaves;
        if constexpr(xdlops_gemm.IsKReduction)
@@ -193,17 +191,17 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
    private:
    // A[K, M]
-    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
+    static constexpr auto a_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
    // B[K, N]
-    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
+    static constexpr auto b_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
-    static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
+    static constexpr auto c_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
-    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         ABlockDesc,
                                                         decltype(a_thread_desc_),
@@ -213,7 +211,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
                                                         K1,
                                                         1>;
-    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         BBlockDesc,
                                                         decltype(b_thread_desc_),
@@ -272,7 +270,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
        const index_t waveId = thread_id / WaveSize;
        const index_t laneId = thread_id % WaveSize;
        const index_t waveId_m = waveId / NWaves;
-        const index_t waveId_n = waveId % NWaves;
        if constexpr(xdlops_gemm.IsKReduction)
        {
@@ -293,7 +290,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
        const index_t thread_id = get_thread_local_1d_id();
        const index_t waveId = thread_id / WaveSize;
        const index_t laneId = thread_id % WaveSize;
-        const index_t waveId_m = waveId / NWaves;
        const index_t waveId_n = waveId % NWaves;
        if constexpr(xdlops_gemm.IsKReduction)
@@ -490,17 +486,17 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
    private:
    // A[K, M]
-    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
+    static constexpr auto a_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
    // B[K, N]
-    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
+    static constexpr auto b_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
-    static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
+    static constexpr auto c_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
-    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         ABlockDesc,
                                                         decltype(a_thread_desc_),
@@ -510,7 +506,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
                                                         1, // K1,
                                                         1>;
-    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         BBlockDesc,
                                                         decltype(b_thread_desc_),
...
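The wave bookkeeping touched in the hunks above (the unused waveId_n / waveId_m lines are what this commit deletes) decomposes a flat thread id into wave and lane, and a wave id into its M/N position. A standalone sketch of that arithmetic follows, assuming waves are laid out row-major over an MWaves x NWaves grid (illustrative only, not CK code):

// Conceptual sketch only, not part of this commit.
#include <cstdio>

struct WavePos
{
    int wave_id, lane_id, wave_m, wave_n;
};

WavePos decompose_thread_id(int thread_id, int wave_size, int n_waves)
{
    WavePos p{};
    p.wave_id = thread_id / wave_size;
    p.lane_id = thread_id % wave_size;
    p.wave_m  = p.wave_id / n_waves; // which wave row
    p.wave_n  = p.wave_id % n_waves; // which wave column
    return p;
}

int main()
{
    // e.g. 256 threads, 64-lane waves, 2 waves along N
    const WavePos p = decompose_thread_id(130, 64, 2);
    std::printf("wave %d lane %d -> (m=%d, n=%d)\n", p.wave_id, p.lane_id, p.wave_m, p.wave_n);
    // prints: wave 2 lane 2 -> (m=1, n=0)
    return 0;
}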
-#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
+#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
-#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
+#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
#include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
+#include "tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
-// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
+// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
-// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
+// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
          InMemoryDataOperationEnum_t DstInMemOp,
          typename BlockSliceLengths,
@@ -33,13 +33,13 @@ template <index_t BlockSize,
          index_t DstScalarStrideInVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseDynamicTensorSliceTransfer_v4
+struct BlockwiseTensorSliceTransfer_v4
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
    using Index = MultiIndex<nDim>;
-    __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(const SrcDesc& src_desc,
+    __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
                                                         const Index& src_block_slice_origin,
                                                         const DstDesc& dst_desc,
                                                         const Index& dst_block_slice_origin)
@@ -77,15 +77,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
        }
    }
-    template <typename SrcBuffer, typename SrcIteratorHacks>
+    template <typename SrcBuffer, typename SrcStepHacks>
-    __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcBuffer& src_buf,
-                            const SrcIteratorHacks& src_iterator_hacks)
+    __device__ void
+    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
-            threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
+            threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
        }
    }
@@ -118,18 +117,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
        }
    }
-    // SrcMoveSliceWindowIteratorHack to control index calculation move slice window
-    template <typename SrcMoveSliceWindowIteratorHack>
+    // SrcMoveSliceWindowStepHack to control index calculation move slice window
+    template <typename SrcMoveSliceWindowStepHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& step,
-                       const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
+                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(
-                src_desc, step, src_move_slice_window_iterator_hack);
+                src_desc, step, src_move_slice_window_step_hack);
        }
    }
@@ -144,10 +143,10 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
    private:
    static constexpr auto thread_cluster_desc_ =
-        make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
+        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
    using ThreadwiseTransfer =
-        ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths,
+        ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
                                         DstInMemOp,
                                         SrcData,
                                         DstData,
...
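BlockwiseTensorSliceTransfer_v4 hands each thread of the cluster its own sub-slice and leaves threads beyond thread_cluster_desc_.GetElementSize() idle (the guard in RunRead and MoveSrcSliceWindow above). Below is a standalone sketch of that partitioning for a 2-D slice; the assumption that the block slice factors as thread-slice times thread-cluster lengths is the example's own, not something spelled out in this hunk:

// Conceptual sketch only, not part of this commit.
#include <array>
#include <cstdio>

int main()
{
    // hypothetical 2-D example
    const std::array<int, 2> thread_slice_lengths{4, 8};   // elements copied per thread
    const std::array<int, 2> thread_cluster_lengths{8, 4}; // threads per dimension
    const int cluster_size = thread_cluster_lengths[0] * thread_cluster_lengths[1]; // 32

    const int tid = 13;
    if(tid < cluster_size) // idle-thread guard, as in RunRead above
    {
        // decompose the flat thread id into a cluster index (row-major)
        const std::array<int, 2> cluster_idx{tid / thread_cluster_lengths[1],
                                             tid % thread_cluster_lengths[1]};
        // this thread's origin inside the block slice
        const std::array<int, 2> origin{cluster_idx[0] * thread_slice_lengths[0],
                                        cluster_idx[1] * thread_slice_lengths[1]};
        std::printf("thread %d copies a %dx%d tile starting at (%d, %d)\n",
                    tid, thread_slice_lengths[0], thread_slice_lengths[1], origin[0], origin[1]);
        // thread 13 -> cluster (3, 1) -> origin (12, 8)
    }
    return 0;
}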
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP #ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP #define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp" #include "cluster_descriptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_tensor_slice_transfer_v2.hpp"
namespace ck { namespace ck {
// this version does following things to avoid scratch memory issue // this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer // 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
...@@ -31,14 +31,13 @@ template <index_t BlockSize, ...@@ -31,14 +31,13 @@ template <index_t BlockSize,
typename DstVectorTensorContiguousDimOrder, typename DstVectorTensorContiguousDimOrder,
bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun> bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4r1 struct BlockwiseTensorSliceTransfer_v4r1
{ {
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1( __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc,
const SrcDesc& src_desc,
const Index& src_block_slice_origin, const Index& src_block_slice_origin,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_block_slice_origin) const Index& dst_block_slice_origin)
...@@ -76,15 +75,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1 ...@@ -76,15 +75,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
} }
} }
template <typename SrcBuffer, typename SrcIteratorHacks> template <typename SrcBuffer, typename SrcStepHacks>
__device__ void RunRead(const SrcDesc& src_desc, __device__ void
const SrcBuffer& src_buf, RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
const SrcIteratorHacks& src_iterator_hacks)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks); threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
} }
} }
...@@ -107,18 +105,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1 ...@@ -107,18 +105,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
} }
} }
// SrcMoveSliceWindowIteratorHack to control index calculation when moving the slice window // SrcMoveSliceWindowStepHack to control index calculation when moving the slice window
template <typename SrcMoveSliceWindowIteratorHack> template <typename SrcMoveSliceWindowStepHack>
__device__ void __device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc, MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step, const Index& step,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrcSliceWindow( threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_iterator_hack); src_desc, step, src_move_slice_window_step_hack);
} }
} }
...@@ -133,10 +131,10 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1 ...@@ -133,10 +131,10 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
private: private:
static constexpr auto thread_cluster_desc_ = static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer = using ThreadwiseTransfer =
ThreadwiseDynamicTensorSliceTransfer_v3r1<ThreadSliceLengths, ThreadwiseTensorSliceTransfer_v3r1<ThreadSliceLengths,
DstInMemOp, DstInMemOp,
SrcData, SrcData,
DstData, DstData,
......
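Both RunRead and MoveSrcSliceWindow in the struct above are wrapped in the same participation guard; a hedged restatement as a standalone helper (the helper itself is illustrative, only the calls inside it appear in the sources):

template <typename ThreadClusterDesc>
__device__ bool thread_participates_in_copy(index_t block_size, const ThreadClusterDesc& cluster_desc)
{
    // a thread takes part only if the whole block maps onto the thread cluster,
    // or its 1-D thread id falls inside the cluster
    return block_size == cluster_desc.GetElementSize() or
           get_thread_local_1d_id() < cluster_desc.GetElementSize();
}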
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP #ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP #define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp" #include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" #include "blockwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp" #include "threadwise_tensor_slice_set.hpp"
namespace ck { namespace ck {
...@@ -25,7 +25,7 @@ __global__ void ...@@ -25,7 +25,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_contraction_dlops_v1r2( kernel_contraction_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
...@@ -84,12 +84,12 @@ template <index_t BlockSize, ...@@ -84,12 +84,12 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridStepHacks,
typename BGridIteratorHacks, typename BGridStepHacks,
typename CGridIteratorHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowStepHacks>
struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -110,15 +110,13 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -110,15 +110,13 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1), make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align); max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1), make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align); max_lds_align);
...@@ -201,7 +199,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -201,7 +199,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
const auto GM11 = Number<GM1PerBlockGM11>{}; const auto GM11 = Number<GM1PerBlockGM11>{};
const auto GM10 = GM1 / GM11; const auto GM10 = GM1 / GM11;
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_dynamic_tensor_descriptor( const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor(
a_grid_desc_gk0_gm0_gm1_gk1, a_grid_desc_gk0_gm0_gm1_gk1,
make_tuple(make_pass_through_transform(GK0), make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GM0), make_pass_through_transform(GM0),
...@@ -222,7 +220,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -222,7 +220,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
const auto GN11 = Number<GN1PerBlockGN11>{}; const auto GN11 = Number<GN1PerBlockGN11>{};
const auto GN10 = GN1 / GN11; const auto GN10 = GN1 / GN11;
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_dynamic_tensor_descriptor( const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor(
b_grid_desc_gk0_gn0_gn1_gk1, b_grid_desc_gk0_gn0_gn1_gk1,
make_tuple(make_pass_through_transform(GK0), make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GN0), make_pass_through_transform(GN0),
...@@ -250,16 +248,16 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -250,16 +248,16 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
constexpr auto BN = GN0 * GN11; constexpr auto BN = GN0 * GN11;
constexpr auto BM1 = constexpr auto BM1 =
Number<container_reduce(BM10BN10ThreadClusterBM10Xs{}, math::multiplies_v2{}, I1) * Number<container_reduce(BM10BN10ThreadClusterBM10Xs{}, math::multiplies{}, I1) *
BM1PerThreadBM11>{}; BM1PerThreadBM11>{};
constexpr auto BN1 = constexpr auto BN1 =
Number<container_reduce(BM10BN10ThreadClusterBN10Xs{}, math::multiplies_v2{}, I1) * Number<container_reduce(BM10BN10ThreadClusterBN10Xs{}, math::multiplies{}, I1) *
BN1PerThreadBN11>{}; BN1PerThreadBN11>{};
constexpr auto BM0 = BM / BM1; constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1; constexpr auto BN0 = BN / BN1;
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor( const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor(
c_grid_desc_gm0_gm1_gn0_gn1, c_grid_desc_gm0_gm1_gn0_gn1,
make_tuple(make_pass_through_transform(GM0), make_tuple(make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)), make_unmerge_transform(make_tuple(GM10, GM11)),
...@@ -268,7 +266,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -268,7 +266,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor( const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor(
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc, c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
make_tuple(make_pass_through_transform(GM10), make_tuple(make_pass_through_transform(GM10),
make_merge_transform(make_tuple(GM0, GM11)), make_merge_transform(make_tuple(GM0, GM11)),
...@@ -277,7 +275,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -277,7 +275,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_dynamic_tensor_descriptor( const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor(
c_gm10_bm_gn10_bn_grid_desc, c_gm10_bm_gn10_bn_grid_desc,
make_tuple(make_pass_through_transform(GM10), make_tuple(make_pass_through_transform(GM10),
make_unmerge_transform(make_tuple(BM0, BM1)), make_unmerge_transform(make_tuple(BM0, BM1)),
...@@ -356,26 +354,24 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -356,26 +354,24 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1), make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align); max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1), make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align); max_lds_align);
// A matrix in LDS memory for blockwise GEMM // A matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_block_desc_gk0_bm_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align); make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
// B matrix in LDS memory for blockwise GEMM // B matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_block_desc_gk0_bn_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align); make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() == static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
...@@ -385,7 +381,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -385,7 +381,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
"wrong!"); "wrong!");
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>, Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
...@@ -409,7 +405,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -409,7 +405,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
make_multi_index(0, 0, 0, 0, 0)); make_multi_index(0, 0, 0, 0, 0));
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>, Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
...@@ -457,8 +453,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -457,8 +453,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 = constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed(
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)); sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -475,7 +470,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -475,7 +470,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_thread_desc_bm0_bm1_bn0_bn1), decltype(c_thread_desc_bm0_bm1_bn0_bn1),
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{} decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
.Run(c_thread_desc_bm0_bm1_bn0_bn1, .Run(c_thread_desc_bm0_bm1_bn0_bn1,
...@@ -501,9 +496,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -501,9 +496,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
...@@ -520,18 +515,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -520,18 +515,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1, blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
...@@ -546,18 +541,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -546,18 +541,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -576,18 +571,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -576,18 +571,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -615,7 +610,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -615,7 +610,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// output: register to global memory // output: register to global memory
{ {
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 = constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
make_dynamic_naive_tensor_descriptor_packed_v2( make_naive_tensor_descriptor_packed(
make_tuple(I1, make_tuple(I1,
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{}, Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{}, Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
...@@ -627,7 +622,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -627,7 +622,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id()); get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, FloatAcc,
FloatC, FloatC,
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1), decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
...@@ -655,7 +650,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 ...@@ -655,7 +650,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
c_thread_buf, c_thread_buf,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_buf, c_grid_buf,
CGridIteratorHacks{}); CGridStepHacks{});
} }
} }
}; };
......
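A condensed, hypothetical skeleton of the LDS double-buffering schedule driven by the main loop above (tail handling simplified; the callables stand in for the RunRead/RunWrite/MoveSrcSliceWindow and blockwise GEMM calls in the sources):

template <typename MoveWindows, typename GlobalRead, typename LdsWrite, typename Gemm>
__device__ void lds_double_buffer_loop(index_t num_k_block_loops,
                                       MoveWindows move_windows,
                                       GlobalRead global_read,
                                       LdsWrite lds_write,
                                       Gemm gemm)
{
    // preload: fetch tile 0 and stage it in the even LDS buffers
    global_read();
    lds_write(0);

    index_t i = 0;
    while(i + 2 < num_k_block_loops)
    {
        // even iteration: fetch tile i+1 while computing on the even buffers
        move_windows();
        __syncthreads();
        global_read();
        gemm(0);       // GEMM on even LDS buffers (tile i)
        lds_write(1);  // stage tile i+1 in the odd LDS buffers

        // odd iteration: fetch tile i+2 while computing on the odd buffers
        move_windows();
        __syncthreads();
        global_read();
        gemm(1);       // GEMM on odd LDS buffers (tile i+1)
        lds_write(0);  // stage tile i+2 in the even LDS buffers

        i += 2;
    }

    // tail (simplified): compute on the remaining staged tile(s) without further global reads
    __syncthreads();
    gemm(0);
}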
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP #ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP #define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r2.hpp" #include "blockwise_gemm_dlops_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp" #include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp" #include "threadwise_tensor_slice_set.hpp"
namespace ck { namespace ck {
...@@ -26,7 +26,7 @@ __global__ void ...@@ -26,7 +26,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_dlops_v1r2( kernel_gemm_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
...@@ -68,8 +68,7 @@ __global__ void ...@@ -68,8 +68,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_dlops_v1r2( kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k_m0_m1_grid_desc, const void CONSTANT* p_a_k_m0_m1_grid_desc,
...@@ -80,16 +79,16 @@ __global__ void ...@@ -80,16 +79,16 @@ __global__ void
// first cast the const void CONSTANT* to a generic const void* // first cast the const void CONSTANT* to a generic const void*
// then cast the generic const void* to the concrete Desc* // then cast the generic const void* to the concrete Desc*
// because the tensor descriptor's copy constructor does not accept an address_space(4) pointer // because the tensor descriptor's copy constructor does not accept an address_space(4) pointer
const auto a_k_m0_m1_grid_desc = const auto a_k_m0_m1_grid_desc = *reinterpret_cast<const AKM0M1GridDesc*>(
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc); cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc));
const auto b_k_n0_n1_grid_desc = const auto b_k_n0_n1_grid_desc = *reinterpret_cast<const BKN0N1GridDesc*>(
*reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc); cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc));
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>( *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>( *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
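// A hedged sketch (not from the sources) of a convenience wrapper for the retrieval
// pattern above; Desc stands for the concrete grid-descriptor type and must be supplied
// explicitly at the call site:
//
// template <typename Desc>
// __device__ Desc load_desc_from_constant(const void CONSTANT* p)
// {
//     // strip address_space(4) first, then reinterpret and copy in generic address space
//     return *reinterpret_cast<const Desc*>(cast_pointer_to_generic_address_space(p));
// }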
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -146,12 +145,12 @@ template <index_t BlockSize, ...@@ -146,12 +145,12 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridStepHacks,
typename BGridIteratorHacks, typename BGridStepHacks,
typename CGridIteratorHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowStepHacks>
struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 struct GridwiseGemmDlops_km_kn_mn_v1r2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -167,12 +166,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -167,12 +166,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -230,7 +229,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -230,7 +229,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
const auto M1 = Number<MPerBlockM1>{}; const auto M1 = Number<MPerBlockM1>{};
const auto M0 = M / M1; const auto M0 = M / M1;
const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor( const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor(
a_k_m_grid_desc, a_k_m_grid_desc,
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))), make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
...@@ -248,7 +247,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -248,7 +247,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
const auto N1 = Number<NPerBlockN1>{}; const auto N1 = Number<NPerBlockN1>{};
const auto N0 = N / N1; const auto N0 = N / N1;
const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor( const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor(
b_k_n_grid_desc, b_k_n_grid_desc,
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))), make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
...@@ -277,7 +276,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -277,7 +276,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
constexpr auto M10 = M1 / M11; constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11; constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor( const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc, c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))), make_unmerge_transform(make_tuple(N0, N10, N11))),
...@@ -352,27 +351,27 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -352,27 +351,27 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align); make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align); make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1>, Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadSliceLengths_K_M0_M1,
...@@ -398,7 +397,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -398,7 +397,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1>, Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadSliceLengths_K_N0_N1,
...@@ -447,8 +446,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -447,8 +446,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
constexpr auto c_m10_m11_n10_n11_thread_desc = constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -465,7 +463,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -465,7 +463,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc), decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{} decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc, .Run(c_m10_m11_n10_n11_thread_desc,
...@@ -477,15 +475,15 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -477,15 +475,15 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over the A and B matrices for threadwise copy // hack to control index calculation when iterating over the A and B matrices for threadwise copy
constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{}; constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{}; constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};
// hack to control index calculation when moving the slice window for the A and B matrices during // hack to control index calculation when moving the slice window for the A and B matrices during
// threadwise copy // threadwise copy
constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack = constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
AGridMoveSliceWindowIteratorHacks{}; AGridMoveSliceWindowStepHacks{};
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack = constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
BGridMoveSliceWindowIteratorHacks{}; BGridMoveSliceWindowStepHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
...@@ -502,9 +500,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -502,9 +500,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
...@@ -519,22 +517,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -519,22 +517,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
do do
{ {
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow( a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_k_m0_m1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack); a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow( b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_k_n0_n1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack); b_k_n0_n1_global_move_slice_window_step_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
...@@ -547,22 +543,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -547,22 +543,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow( a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_k_m0_m1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack); a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow( b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_k_n0_n1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack); b_k_n0_n1_global_move_slice_window_step_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -581,18 +575,18 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -581,18 +575,18 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack); a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack); b_k_n0_n1_global_move_slice_window_step_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -619,19 +613,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -619,19 +613,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
// output: register to global memory // output: register to global memory
{ {
constexpr index_t M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
constexpr index_t N11 =
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
constexpr index_t M10 = MPerBlockM1 / M11;
constexpr index_t N10 = NPerBlockN1 / N11;
constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2( make_naive_tensor_descriptor_packed(
make_tuple(I1, make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{}, Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{}, Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
...@@ -642,7 +625,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -642,7 +625,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, FloatAcc,
FloatC, FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
...@@ -670,7 +653,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 ...@@ -670,7 +653,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
c_thread_buf, c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf, c_grid_buf,
CGridIteratorHacks{}); CGridStepHacks{});
} }
} }
}; };
......
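A small worked illustration (not from the sources) of the unmerge transforms used throughout the descriptors above:

// Splitting a length-M dimension into (M0, M1) with M1 = MPerBlockM1 maps a flat index m to
//   (m0, m1) = (m / M1, m % M1),   i.e.   m = m0 * M1 + m1.
// e.g. M = 1024, MPerBlockM1 = 128  ->  M0 = 8
//      m = 300                      ->  (m0, m1) = (2, 44), since 2 * 128 + 44 = 300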
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP #ifndef CK_GRIDWISE_GEMM_V1R3_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP #define CK_GRIDWISE_GEMM_V1R3_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp" #include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" #include "blockwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp" #include "threadwise_tensor_slice_set.hpp"
namespace ck { namespace ck {
...@@ -26,7 +26,7 @@ __global__ void ...@@ -26,7 +26,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_dlops_v1r3( kernel_gemm_dlops_v1r3(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
...@@ -68,8 +68,7 @@ __global__ void ...@@ -68,8 +68,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_dlops_v1r3( kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc, const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
...@@ -80,16 +79,16 @@ __global__ void ...@@ -80,16 +79,16 @@ __global__ void
// first cast the const void CONSTANT* to a generic const void* // first cast the const void CONSTANT* to a generic const void*
// then cast the generic const void* to the concrete Desc* // then cast the generic const void* to the concrete Desc*
// because the tensor descriptor's copy constructor does not accept an address_space(4) pointer // because the tensor descriptor's copy constructor does not accept an address_space(4) pointer
const auto a_k0_m0_m1_k1_grid_desc = const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast<const AK0M0M1K1GridDesc*>(
*reinterpret_cast<const AK0M0M1K1GridDesc*>((const void*)p_a_k0_m0_m1_k1_grid_desc); cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc));
const auto b_k0_n0_n1_k1_grid_desc = const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast<const BK0N0N1K1GridDesc*>(
*reinterpret_cast<const BK0N0N1K1GridDesc*>((const void*)p_b_k0_n0_n1_k1_grid_desc); cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc));
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>( *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>( *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -142,12 +141,12 @@ template <index_t BlockSize, ...@@ -142,12 +141,12 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridStepHacks,
typename BGridIteratorHacks, typename BGridStepHacks,
typename CGridIteratorHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowStepHacks>
struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 struct GridwiseGemmDlops_km_kn_mn_v1r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -164,12 +163,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -164,12 +163,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
...@@ -191,12 +190,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -191,12 +190,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
const auto M = a_k0_m_k1_grid_desc.GetLength(I1); const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto K1 = a_k0_m_k1_grid_desc.GetLength(I2);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
} }
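// Quick numeric illustration (not from the sources) of the checks above: with
// M = N = 1024, K0 = 128, matching K1 on the A and B descriptors, MPerBlockM1 =
// NPerBlockN1 = 128 and KPerBlock = 8, every shape and divisibility condition holds
// (1024 % 128 == 0 and 128 % 8 == 0).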
...@@ -231,8 +230,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -231,8 +230,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
const auto M1 = Number<MPerBlockM1>{}; const auto M1 = Number<MPerBlockM1>{};
const auto M0 = M / M1; const auto M0 = M / M1;
const auto a_k0_m0_m1_k1_grid_desc = transform_dynamic_tensor_descriptor( const auto a_k0_m0_m1_k1_grid_desc =
a_k0_m_k1_grid_desc, transform_tensor_descriptor(a_k0_m_k1_grid_desc,
make_tuple(make_pass_through_transform(K0), make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)), make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
...@@ -251,8 +250,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -251,8 +250,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
const auto N1 = Number<NPerBlockN1>{}; const auto N1 = Number<NPerBlockN1>{};
const auto N0 = N / N1; const auto N0 = N / N1;
const auto b_k0_n0_n1_k1_grid_desc = transform_dynamic_tensor_descriptor( const auto b_k0_n0_n1_k1_grid_desc =
b_k0_n_k1_grid_desc, transform_tensor_descriptor(b_k0_n_k1_grid_desc,
make_tuple(make_pass_through_transform(K0), make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)), make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
...@@ -275,16 +274,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -275,16 +274,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
const auto N0 = N / N1; const auto N0 = N / N1;
constexpr auto M11 = constexpr auto M11 =
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies_v2{}, I1) * Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
M1PerThreadM111>{}; M1PerThreadM111>{};
constexpr auto N11 = constexpr auto N11 =
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies_v2{}, I1) * Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
N1PerThreadN111>{}; N1PerThreadN111>{};
constexpr auto M10 = M1 / M11; constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11; constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor( const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc, c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))), make_unmerge_transform(make_tuple(N0, N10, N11))),
...@@ -355,23 +354,23 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -355,23 +354,23 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k0_m0_m1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k0_n0_n1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM // A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM // B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
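A rough sketch of why the alignment argument matters for these LDS descriptors, under the assumption (not a quote of the helper's implementation) that the aligned descriptor rounds each row's extent up to a multiple of max_lds_align so every row starts on an aligned LDS address:

    // Minimal sketch; all values are placeholders and the padding rule is an assumption.
    constexpr int round_up(int x, int align) { return ((x + align - 1) / align) * align; }

    constexpr int k_per_block_ex  = 8;
    constexpr int m_per_block_ex  = 128;
    constexpr int k1_ex           = 4;
    constexpr int max_lds_align_ex = 8;

    // Packed row length would be m_per_block_ex * k1_ex; the aligned stride pads it.
    constexpr int row_stride_ex    = round_up(m_per_block_ex * k1_ex, max_lds_align_ex);
    constexpr int element_space_ex = k_per_block_ex * row_stride_ex; // roughly what GetElementSpaceSize() reports under this assumption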
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
...@@ -381,7 +380,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -381,7 +380,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
"wrong!"); "wrong!");
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>, Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
...@@ -405,7 +404,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -405,7 +404,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
make_multi_index(0, 0, 0, 0)); make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>, Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
...@@ -453,8 +452,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -453,8 +452,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_m10_m11_n10_n11_thread_desc = constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -471,7 +469,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -471,7 +469,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc), decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{} decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc, .Run(c_m10_m11_n10_n11_thread_desc,
...@@ -496,8 +494,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -496,8 +494,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
...@@ -516,18 +514,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -516,18 +514,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
b_blockwise_copy.RunRead(
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
...@@ -542,18 +538,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -542,18 +538,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{}); AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{}); BGridMoveSliceWindowStepHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
b_blockwise_copy.RunRead(
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -570,18 +564,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -570,18 +564,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(
a_block_slice_copy_step, a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
AGridMoveSliceWindowIteratorHacks{}); b_blockwise_copy.MoveSrcSliceWindow(
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run( blockwise_gemm.Run(
...@@ -608,21 +600,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -608,21 +600,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// output: register to global memory // output: register to global memory
{ {
constexpr auto M11 =
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies_v2{}, I1) *
M1PerThreadM111>{};
constexpr auto N11 =
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies_v2{}, I1) *
N1PerThreadN111>{};
constexpr index_t M10 = MPerBlockM1 / M11;
constexpr index_t N10 = NPerBlockN1 / N11;
constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2( make_naive_tensor_descriptor_packed(
make_tuple(I1, make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{}, Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{}, Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
...@@ -634,7 +613,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -634,7 +613,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id()); get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, FloatAcc,
FloatC, FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
...@@ -662,7 +641,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 ...@@ -662,7 +641,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
c_thread_buf, c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf, c_grid_buf,
CGridIteratorHacks{}); CGridStepHacks{});
} }
} }
}; };
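The main K loop of this struct is a standard LDS ping-pong: while the blockwise GEMM consumes the data already in one LDS buffer, the next K slice is prefetched into the other, and a 1- or 2-iteration tail is handled outside the unrolled body. A stripped-down control-flow sketch (the helper names below are placeholders, not the kernel's API, and the loop bounds are simplified):

    // Illustrative skeleton of the even/odd LDS double buffering above.
    inline void lds_double_buffer_loop(int num_k_slices)
    {
        float lds_even[1], lds_odd[1];                        // stand-ins for the two LDS tiles
        auto prefetch = [](float* dst) { dst[0] = 0.0f; };    // global -> registers -> LDS copy
        auto compute  = [](const float* src) { (void)src; };  // blockwise GEMM on one tile

        prefetch(lds_even);                                   // preload before the main loop
        for(int k = 0; k + 2 < num_k_slices; k += 2)
        {
            prefetch(lds_odd);  compute(lds_even);            // even iteration
            prefetch(lds_even); compute(lds_odd);             // odd iteration
        }
        // remaining 1 or 2 slices handled by the tail branches (HasDoubleTailKBlockLoop)
    }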
......
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP #ifndef CK_GRIDWISE_GEMM_V2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP #define CK_GRIDWISE_GEMM_V2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp" #include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "blockwise_gemm_dlops_v3.hpp" #include "blockwise_gemm_dlops_v3.hpp"
namespace ck { namespace ck {
...@@ -42,12 +42,12 @@ template <index_t BlockSize, ...@@ -42,12 +42,12 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks, typename AGlobalStepHacks,
typename BGlobalIteratorHacks, typename BGlobalStepHacks,
typename CGlobalIteratorHacks, typename CGlobalStepHacks,
typename AGlobalMoveSliceWindowIteratorHacks, typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowIteratorHacks> typename BGlobalMoveSliceWindowStepHacks>
struct GridwiseDynamicGemmDlops_km_kn_mn_v3 struct GridwiseGemmDlops_km_kn_mn_v3
{ {
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
...@@ -58,7 +58,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -58,7 +58,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
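GetSharedMemoryNumberOfByte() sizes the LDS allocation from this aligned descriptor's element space. A back-of-the-envelope estimate with hypothetical numbers (the real figure comes from GetElementSpaceSize(), which may include alignment padding):

    // Hypothetical sizes, for illustration only:
    //   E = 64, KPerBlock = 16, FloatAB = float  ->  about 64 * 16 * sizeof(float) bytes.
    constexpr int lds_bytes_estimate = 64 * 16 * 4; // = 4096 bytes, before any alignment padding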
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -102,7 +102,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -102,7 +102,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// divide block work by [M, N] // divide block work by [M, N]
#if 0 #if 0
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto ho_block_work_num = Ho / Number<HoPerBlock>{}; const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{}; const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num; const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
...@@ -114,7 +113,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -114,7 +113,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#else #else
// Hack: this forces the result into an SGPR
const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock); const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock); const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num; const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
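The readfirstlane calls are a scalarization hint: the quotient is identical in every lane of the wavefront, so broadcasting lane 0's copy gives the compiler license to keep it in a scalar register rather than a VGPR. A minimal usage sketch (the intrinsic is the stock clang AMDGPU builtin; the surrounding function and names are placeholders):

    #include <hip/hip_runtime.h> // sketch assumes a HIP translation unit

    // Force a wave-uniform integer into a scalar register.
    __device__ int block_work_num(int total, int per_block)
    {
        // total / per_block is the same in every lane, so reading lane 0 is lossless.
        return __builtin_amdgcn_readfirstlane(total / per_block);
    }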
...@@ -134,22 +132,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -134,22 +132,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_e_n_ho_wo_block_desc = constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{})); Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx // TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k_n_ho_wo_thread_desc = constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto blockwise_gemm = auto blockwise_gemm =
...@@ -184,7 +180,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -184,7 +180,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<E, KPerBlock>, Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K, ABlockTransferThreadSliceLengths_E_K,
...@@ -203,18 +199,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -203,18 +199,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>( true>(a_e_k_global_desc,
a_e_k_global_desc,
make_multi_index(0, k_block_data_on_global), make_multi_index(0, k_block_data_on_global),
a_e_k_desc, a_e_k_desc,
make_multi_index(0, 0)); make_multi_index(0, 0));
constexpr auto b_e_n_ho_wo_thread_desc = constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2< auto b_threadwise_transfer =
FloatAB, ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB, FloatAB,
decltype(b_e_n_ho_wo_global_desc), decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc), decltype(b_e_n_ho_wo_thread_desc),
...@@ -223,7 +217,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -223,7 +217,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
1, 1,
true>(b_e_n_ho_wo_global_desc, true>(
b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
...@@ -232,11 +227,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -232,11 +227,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc, FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
c_thread_buf; c_thread_buf;
// initialize output thread tensor // initialize output thread tensor
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k_n_ho_wo_thread_desc), decltype(c_k_n_ho_wo_thread_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{} Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
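The ThreadwiseTensorSliceSet_v1 call simply clears the thread-private accumulator tile before the main loop. In spirit it amounts to the following (a sketch, not the helper's actual implementation; names are placeholders):

    // Zero every element of the per-thread accumulator, mirroring .Run(..., FloatAcc{0}).
    inline void zero_accumulator(float* c_thread, int num_elements)
    {
        for(int i = 0; i < num_elements; ++i)
            c_thread[i] = 0.0f;
    }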
...@@ -244,32 +240,32 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -244,32 +240,32 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over the A and B matrices for threadwise copy // hack to control index calculation when iterating over the A and B matrices for threadwise copy
constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{}; constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{}; constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
// hack to control index calculation when moving the slice window for the A and B matrices for // hack to control index calculation when moving the slice window for the A and B matrices for
// threadwise copy // threadwise copy
constexpr auto a_e_k_global_move_slice_window_iterator_hack = constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
AGlobalMoveSliceWindowIteratorHacks{}; constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = BGlobalMoveSliceWindowStepHacks{};
BGlobalMoveSliceWindowIteratorHacks{};
// double register buffer for b // double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB, FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
b_thread_even_buf, b_thread_odd_buf; b_thread_even_buf, b_thread_odd_buf;
// LDS double buffer: preload data // LDS double buffer: preload data
{ {
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks); a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_step_hacks);
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf); a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
} }
...@@ -293,7 +289,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -293,7 +289,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window // TODO: @Zhang Jing: blockwise gemm should be able to move slice window
...@@ -309,7 +305,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -309,7 +305,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
...@@ -332,7 +328,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -332,7 +328,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
...@@ -351,13 +347,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -351,13 +347,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// output: register to global memory // output: register to global memory
{ {
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
const index_t k_thread_data_on_global = const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread; k_block_data_on_global + k_thread_id * KPerThread;
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatAcc,
FloatC, FloatC,
decltype(c_k_n_ho_wo_thread_desc), decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc), decltype(c_k_n_ho_wo_global_desc),
...@@ -376,7 +371,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 ...@@ -376,7 +371,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
c_thread_buf, c_thread_buf,
c_k_n_ho_wo_global_desc, c_k_n_ho_wo_global_desc,
c_global_buf, c_global_buf,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_step_hacks);
} }
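The epilogue above computes each thread's K origin on the global C tensor (k_block_data_on_global + k_thread_id * KPerThread) and copies the register tile out through ThreadwiseTensorSliceTransfer_v1r3. A deliberately flattened sketch of that write-back (placeholder names; the real copy goes through the tensor descriptors and CGlobalStepHacks rather than raw pointers):

    // Each thread writes its KPerThread x HoPerThread x WoPerThread register tile
    // to global C at an offset derived from its block and thread ids.
    inline void write_c_tile(float* c_global, const float* c_thread,
                             int k_block_data_on_global, int k_thread_id,
                             int k_per_thread, int elems_per_k, int global_k_stride)
    {
        const int k_origin = k_block_data_on_global + k_thread_id * k_per_thread;
        for(int k = 0; k < k_per_thread; ++k)
            for(int i = 0; i < elems_per_k; ++i)
                c_global[(k_origin + k) * global_k_stride + i] = c_thread[k * elems_per_k + i];
    }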
} }
......
...@@ -21,7 +21,7 @@ template <typename FloatA, ...@@ -21,7 +21,7 @@ template <typename FloatA,
typename TKLengths, typename TKLengths,
typename TMLengths, typename TMLengths,
typename TNLengths, typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
...@@ -97,8 +97,7 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 ...@@ -97,8 +97,7 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
amd_inner_product_dlop<FloatA, FloatB, FloatC>( inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}], b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{})); c_buf(Number<c_offset>{}));
}); });
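For scalar operand types, the renamed inner_product call reduces to a multiply-accumulate into the C element. A scalar reference sketch of that contract (not the library's implementation, which dispatches to packed-math / dot-product instructions where the hardware provides them):

    // c accumulates a * b; TA/TB/TC mirror FloatA/FloatB/FloatC above.
    template <typename TA, typename TB, typename TC>
    inline void inner_product_ref(const TA& a, const TB& b, TC& c)
    {
        c += static_cast<TC>(a) * static_cast<TC>(b);
    }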
...@@ -124,7 +123,7 @@ template <typename FloatA, ...@@ -124,7 +123,7 @@ template <typename FloatA,
typename TKLengths, typename TKLengths,
typename TMLengths, typename TMLengths,
typename TNLengths, typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
...@@ -214,7 +213,7 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ ...@@ -214,7 +213,7 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>( inner_product<a_vector_t, b_vector_t, FloatC>(
a_vec.template AsType<a_vector_t>()[I0], a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0], b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{})); c_buf(Number<c_offset>{}));
......