"examples/vscode:/vscode.git/clone" did not exist on "bbdbd5820eeb1d2e8dd51afbe0e8f7c7f31838cd"
Unverified commit f63a23ac authored by Chao Liu, committed by GitHub

[MIOpen Downstream] Initial MIOpen integration (#52)

* update online kernel wrapper: bundle all descriptors in a tuple

* change __CONSTANT__ to CONSTANT

* rename

* adding tuning

* added IsValidCompileParameter

* reorganize

* adding tunable for fp16 and int8

* fix kernel compile warnings and bugs

* suppress warning about casting a CONSTANT (address space 4) pointer

* fix building issue
parent 12649254
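The first bullet above, bundling all tensor descriptors for the online (runtime-compiled) kernel wrapper into a single tuple, is easiest to picture with a small host-side sketch. The sketch below is illustrative only; the struct names, fields, and helper are invented for this note and are not taken from this commit.

#include <tuple>
#include <cstdint>

// Hypothetical problem descriptor for a forward convolution.
struct ConvProblemDesc
{
    int64_t N, K, C, Y, X, Ho, Wo;
};

// Hypothetical tiling descriptor chosen by the tuning logic.
struct TileDesc
{
    int64_t MPerBlock, NPerBlock, KPerBlock;
};

// Bundle every descriptor the kernel needs into one tuple-typed argument,
// so the wrapper marshals a single blob instead of many loose parameters.
using KernelDescTuple = std::tuple<ConvProblemDesc, TileDesc>;

inline KernelDescTuple MakeKernelDescTuple(const ConvProblemDesc& conv, const TileDesc& tile)
{
    return std::make_tuple(conv, tile);
}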
...@@ -7,9 +7,12 @@ ...@@ -7,9 +7,12 @@
namespace ck { namespace ck {
// GemmM = K // GemmM0 = 1
// GemmN = N * Ho * Wo // GemmM1 = K
// GemmK = C * Y * X // GemmN0 = N0
// GemmN1 = (N / N0) * Ho * Wo
// GemmK0 = (C / C0) * Y * X
// GemmK1 = C0
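// Illustrative numeric example (added for this write-up, not part of the commit):
// for a hypothetical forward convolution with N = 128, K = 256, C = 64,
// Y = X = 3, Ho = Wo = 14 and splits N0 = 2, C0 = 8, the mapping above gives
//   GemmM0 = 1,        GemmM1 = K = 256
//   GemmN0 = N0 = 2,   GemmN1 = (N / N0) * Ho * Wo = 64 * 14 * 14 = 12544
//   GemmK0 = (C / C0) * Y * X = 8 * 3 * 3 = 72,   GemmK1 = C0 = 8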
template <typename... Wei, template <typename... Wei,
typename... In, typename... In,
typename... Out, typename... Out,
......
...@@ -46,7 +46,7 @@ struct DynamicPassThrough ...@@ -46,7 +46,7 @@ struct DynamicPassThrough
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff& idx_diff_up,
LowIdx& idx_low, LowIdx& idx_low,
const UpIdx& idx_up_new, const UpIdx&,
Number<Hack>) Number<Hack>)
{ {
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
...@@ -136,7 +136,7 @@ struct DynamicPad ...@@ -136,7 +136,7 @@ struct DynamicPad
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff& idx_diff_up,
LowIdx& idx_low, LowIdx& idx_low,
const UpIdx& idx_up_new, const UpIdx&,
Number<Hack>) Number<Hack>)
{ {
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
...@@ -227,7 +227,7 @@ struct DynamicLeftPad ...@@ -227,7 +227,7 @@ struct DynamicLeftPad
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff& idx_diff_up,
LowIdx& idx_low, LowIdx& idx_low,
const UpIdx& idx_up_new, const UpIdx&,
Number<Hack>) Number<Hack>)
{ {
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
...@@ -318,7 +318,7 @@ struct DynamicRightPad ...@@ -318,7 +318,7 @@ struct DynamicRightPad
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff& idx_diff_up,
LowIdx& idx_low, LowIdx& idx_low,
const UpIdx& idx_up_new, const UpIdx&,
Number<Hack>) Number<Hack>)
{ {
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
...@@ -420,7 +420,7 @@ struct DynamicEmbed ...@@ -420,7 +420,7 @@ struct DynamicEmbed
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff& idx_diff_up,
LowIdx& idx_low, LowIdx& idx_low,
const UpIdx& idx_up_new, const UpIdx&,
Number<Hack>) const Number<Hack>) const
{ {
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp && static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&
...@@ -1096,7 +1096,7 @@ struct DynamicMerge_v2_magic_division ...@@ -1096,7 +1096,7 @@ struct DynamicMerge_v2_magic_division
typename UpIdx, typename UpIdx,
index_t Hack> index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff&,
LowIdx& idx_low, LowIdx& idx_low,
const UpIdx& idx_up_new, const UpIdx& idx_up_new,
Number<Hack>) const Number<Hack>) const
...@@ -1254,7 +1254,7 @@ struct DynamicMerge_v2r2_magic_division ...@@ -1254,7 +1254,7 @@ struct DynamicMerge_v2r2_magic_division
typename UpIdx, typename UpIdx,
index_t Hack> index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff&,
LowIdx& idx_low, LowIdx& idx_low,
const UpIdx& idx_up_new, const UpIdx& idx_up_new,
Number<Hack>) const Number<Hack>) const
...@@ -1383,7 +1383,7 @@ struct DynamicUnMerge ...@@ -1383,7 +1383,7 @@ struct DynamicUnMerge
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff& idx_diff_up,
LowIdx& idx_low, LowIdx& idx_low,
const UpIdx& idx_up_new, const UpIdx&,
Number<Hack>) const Number<Hack>) const
{ {
CalculateLowerIndex(idx_diff_low, idx_diff_up); CalculateLowerIndex(idx_diff_low, idx_diff_up);
...@@ -1597,7 +1597,7 @@ struct DynamicVectorize ...@@ -1597,7 +1597,7 @@ struct DynamicVectorize
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff& idx_diff_up,
LowIdx& idx_low, LowIdx& idx_low,
const UpIdx& idx_up_new, const UpIdx&,
Number<Hack>) const Number<Hack>) const
{ {
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
...@@ -1654,7 +1654,7 @@ struct DynamicSlice ...@@ -1654,7 +1654,7 @@ struct DynamicSlice
__host__ __device__ constexpr DynamicSlice() = default; __host__ __device__ constexpr DynamicSlice() = default;
__host__ __device__ constexpr DynamicSlice(const LowLength& low_length, __host__ __device__ constexpr DynamicSlice(const LowLength&,
const SliceBegin& slice_begin, const SliceBegin& slice_begin,
const SliceEnd& slice_end) const SliceEnd& slice_end)
: up_lengths_{make_tuple(slice_end - slice_begin)}, : up_lengths_{make_tuple(slice_end - slice_begin)},
...@@ -1687,7 +1687,7 @@ struct DynamicSlice ...@@ -1687,7 +1687,7 @@ struct DynamicSlice
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff& idx_diff_up,
LowIdx& idx_low, LowIdx& idx_low,
const UpIdx& idx_up_new, const UpIdx&,
Number<Hack>) Number<Hack>)
{ {
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
...@@ -1709,8 +1709,7 @@ struct DynamicSlice ...@@ -1709,8 +1709,7 @@ struct DynamicSlice
} }
template <typename UpIdx> template <typename UpIdx>
__host__ __device__ constexpr bool __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
{ {
return true; return true;
} }
......
...@@ -317,7 +317,7 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, ...@@ -317,7 +317,7 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
NewUpperDimensionNewVisibleIdss{}); NewUpperDimensionNewVisibleIdss{});
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value && static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_old_top_ids)>::value, is_valid_sequence_map<decltype(all_new_top_ids)>::value,
"wrong!"); "wrong!");
} }
...@@ -395,7 +395,6 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe ...@@ -395,7 +395,6 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
MultiIndex<ndim_hidden> idx_hidden; MultiIndex<ndim_hidden> idx_hidden;
...@@ -492,11 +491,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate( ...@@ -492,11 +491,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator) const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator)
{ {
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
using HiddenIndex = MultiIndex<ndim_hidden>;
// this is what needs to be calculated // this is what needs to be calculated
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>(); auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
......
...@@ -236,15 +236,15 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -236,15 +236,15 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
// shift // shift
constexpr index_t adaptor0_max_hidden_id = [&]() { constexpr index_t adaptor0_max_hidden_id = [&]() {
index_t adaptor0_max_hidden_id = NumericLimits<index_t>::Min(); index_t adaptor0_max_hidden_id_ = NumericLimits<index_t>::Min();
static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) { static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low = constexpr index_t ndim_low =
TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension(); TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension();
static_for<0, ndim_low, 1>{}([&](auto idim_low) { static_for<0, ndim_low, 1>{}([&](auto idim_low) {
adaptor0_max_hidden_id = adaptor0_max_hidden_id_ =
math::max(adaptor0_max_hidden_id, math::max(adaptor0_max_hidden_id_,
TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value); TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
}); });
...@@ -252,17 +252,17 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -252,17 +252,17 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension(); TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension();
static_for<0, ndim_up, 1>{}([&](auto idim_up) { static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor0_max_hidden_id = adaptor0_max_hidden_id_ =
math::max(adaptor0_max_hidden_id, math::max(adaptor0_max_hidden_id_,
TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value); TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
}); });
}); });
return adaptor0_max_hidden_id; return adaptor0_max_hidden_id_;
}(); }();
constexpr index_t adaptor1_min_hidden_id = [&]() { constexpr index_t adaptor1_min_hidden_id = [&]() {
index_t adaptor1_min_hidden_id = NumericLimits<index_t>::Max(); index_t adaptor1_min_hidden_id_ = NumericLimits<index_t>::Max();
static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) { static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low = constexpr index_t ndim_low =
...@@ -285,7 +285,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -285,7 +285,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
if(!is_bottom_dim) if(!is_bottom_dim)
{ {
adaptor1_min_hidden_id = math::min(adaptor1_min_hidden_id, low_dim_hidden_id); adaptor1_min_hidden_id_ = math::min(adaptor1_min_hidden_id_, low_dim_hidden_id);
} }
}); });
...@@ -294,13 +294,13 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -294,13 +294,13 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
// get the min of all upper dimensions // get the min of all upper dimensions
static_for<0, ndim_up, 1>{}([&](auto idim_up) { static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor1_min_hidden_id = adaptor1_min_hidden_id_ =
math::min(adaptor1_min_hidden_id, math::min(adaptor1_min_hidden_id_,
TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value); TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
}); });
}); });
return adaptor1_min_hidden_id; return adaptor1_min_hidden_id_;
}(); }();
constexpr index_t adaptor1_hidden_id_shift = constexpr index_t adaptor1_hidden_id_shift =
...@@ -321,11 +321,11 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -321,11 +321,11 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
// sequence in, sequence out // sequence in, sequence out
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
{ {
auto low_dim_hidden_ids_1_mod = to_multi_index(low_dim_hidden_ids_1); auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
// shift hidden id so every dim id is unique // shift hidden id so every dim id is unique
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
low_dim_hidden_ids_1_mod(idim_low_1) += adaptor1_hidden_id_shift; low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
}); });
// match hidden id // match hidden id
...@@ -335,13 +335,13 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -335,13 +335,13 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
if constexpr(low_dim_hidden_ids_1[idim_low_1] == if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1]) TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1])
{ {
low_dim_hidden_ids_1_mod(idim_low_1) = low_dim_hidden_ids_1_mod_(idim_low_1) =
TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1]; TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1];
} }
}); });
}); });
return low_dim_hidden_ids_1_mod; return low_dim_hidden_ids_1_mod_;
} }
(); ();
...@@ -367,14 +367,14 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -367,14 +367,14 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
// sequence in, constexpr tuple out // sequence in, constexpr tuple out
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
{ {
auto up_dim_hidden_ids_1_mod = to_multi_index(up_dim_hidden_ids_1); auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
// shift hidden id // shift hidden id
static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) { static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
up_dim_hidden_ids_1_mod(idim_up_1) += adaptor1_hidden_id_shift; up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
}); });
return up_dim_hidden_ids_1_mod; return up_dim_hidden_ids_1_mod_;
} }
(); ();
......
...@@ -14,7 +14,7 @@ namespace ck { ...@@ -14,7 +14,7 @@ namespace ck {
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
InMemoryDataOperation DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
typename ThreadSliceLengths, typename ThreadSliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
......
...@@ -14,7 +14,7 @@ namespace ck { ...@@ -14,7 +14,7 @@ namespace ck {
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
InMemoryDataOperation DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
typename ThreadSliceLengths, typename ThreadSliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
......
#ifndef CK_BLOCKWISE_GEMM_V2R2_HPP #ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#define CK_BLOCKWISE_GEMM_V2R2_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_adaptor.hpp" #include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_contraction.hpp" #include "threadwise_contraction_dlops.hpp"
namespace ck { namespace ck {
...@@ -40,7 +40,7 @@ template <index_t BlockSize, ...@@ -40,7 +40,7 @@ template <index_t BlockSize,
typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() && typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
BKNBlockDesc::IsKnownAtCompileTime(), BKNBlockDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{ {
using AIndex = MultiIndex<3>; using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>; using BIndex = MultiIndex<3>;
...@@ -140,7 +140,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2 ...@@ -140,7 +140,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{}); static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{});
public: public:
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2() __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
: c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock( : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
get_thread_local_1d_id())}, get_thread_local_1d_id())},
a_thread_copy_{ a_thread_copy_{
...@@ -183,7 +183,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2 ...@@ -183,7 +183,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0)); return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
} }
__host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; } __host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; }
...@@ -207,13 +207,13 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2 ...@@ -207,13 +207,13 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0, CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
"wrong"); "wrong");
auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
a_k_m0_m1_thread_desc_.GetElementSpaceSize()); a_k_m0_m1_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
b_k_n0_n1_thread_desc_.GetElementSpaceSize()); b_k_n0_n1_thread_desc_.GetElementSpaceSize());
constexpr auto threadwise_gemm = constexpr auto threadwise_gemm =
ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA, ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1<FloatA,
FloatB, FloatB,
FloatC, FloatC,
decltype(a_k_m0_m1_thread_desc_), decltype(a_k_m0_m1_thread_desc_),
......
#ifndef CK_BLOCKWISE_GEMM_V2R3_HPP #ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
#define CK_BLOCKWISE_GEMM_V2R3_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_adaptor.hpp" #include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_contraction.hpp" #include "threadwise_contraction_dlops.hpp"
namespace ck { namespace ck {
...@@ -21,6 +21,7 @@ namespace ck { ...@@ -21,6 +21,7 @@ namespace ck {
// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time // 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
// 2. CThreadBuffer is StaticBuffer // 2. CThreadBuffer is StaticBuffer
// Also assume: // Also assume:
// BM10BN10ThreadClusterBM10Xs::Size() == BM10BN10ThreadClusterBN10Xs::Size() == 2
// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) // BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize, template <index_t BlockSize,
typename FloatA, typename FloatA,
...@@ -31,16 +32,16 @@ template <index_t BlockSize, ...@@ -31,16 +32,16 @@ template <index_t BlockSize,
index_t BM1PerThreadBM11, index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11, index_t BN1PerThreadBN11,
index_t BK0PerThread, index_t BK0PerThread,
index_t BM10BN10ThreadClusterBM100, typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
index_t BM10BN10ThreadClusterBN100, // BM10BN10ThreadClusterBM101, ...>
index_t BM10BN10ThreadClusterBM101, typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
index_t BM10BN10ThreadClusterBN101, // BM10BN10ThreadClusterBN101, ...>
index_t AThreadCopyScalarPerVector_BM11, index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11, index_t BThreadCopyScalarPerVector_BN11,
typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{ {
using AIndex = MultiIndex<3>; using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>; using BIndex = MultiIndex<3>;
...@@ -56,19 +57,17 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_ ...@@ -56,19 +57,17 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_
static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1); static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1); static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);
static constexpr index_t BM100 = BM10BN10ThreadClusterBM100; static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
static constexpr index_t BN100 = BM10BN10ThreadClusterBN100; static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];
static constexpr index_t BM101 = BM10BN10ThreadClusterBM101; static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
static constexpr index_t BN101 = BM10BN10ThreadClusterBN101; static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];
static constexpr index_t BM11 = BM1PerThreadBM11; static constexpr index_t BM11 = BM1PerThreadBM11;
static constexpr index_t BN11 = BN1PerThreadBN11; static constexpr index_t BN11 = BN1PerThreadBN11;
static constexpr index_t BM1 = static constexpr index_t BM1 = BM100 * BM101 * BM11;
BM10BN10ThreadClusterBM100 * BM10BN10ThreadClusterBM101 * BM1PerThreadBM11; static constexpr index_t BN1 = BN100 * BN101 * BN11;
static constexpr index_t BN1 =
BM10BN10ThreadClusterBN100 * BM10BN10ThreadClusterBN101 * BN1PerThreadBN11;
static constexpr index_t BM0 = BM / BM1; static constexpr index_t BM0 = BM / BM1;
static constexpr index_t BN0 = BN / BN1; static constexpr index_t BN0 = BN / BN1;
...@@ -149,7 +148,7 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_ ...@@ -149,7 +148,7 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
public: public:
__device__ BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() __device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())}, get_thread_local_1d_id())},
a_thread_copy_{ a_thread_copy_{
...@@ -170,6 +169,11 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_ ...@@ -170,6 +169,11 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_
BBlockDesc_BK0_BN_BK1{}.GetLength(I0), BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
"wrong! K dimension not consistent"); "wrong! K dimension not consistent");
// TODO remove this restriction
static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
BM10BN10ThreadClusterBN10Xs::Size() == 2,
"wrong!");
// TODO: remove this restriction // TODO: remove this restriction
static_assert(BM0 == 2 && BN0 == 2, "wrong"); static_assert(BM0 == 2 && BN0 == 2, "wrong");
} }
...@@ -195,14 +199,14 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_ ...@@ -195,14 +199,14 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0)); return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
} }
template <typename CThreadDesc_BM0_BM11_BN0_BN11, template <typename CThreadDesc_BM0_BM11_BN0_BN11,
typename ABlockBuffer, typename ABlockBuffer,
typename BBlockBuffer, typename BBlockBuffer,
typename CThreadBuffer> typename CThreadBuffer>
__device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11& c_m0_m1_n0_n1_thread_desc, __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
const ABlockBuffer& a_block_buf, const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
...@@ -216,13 +220,13 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_ ...@@ -216,13 +220,13 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0, CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
"wrong"); "wrong");
auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
constexpr auto threadwise_contraction = constexpr auto threadwise_contraction =
ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
FloatA, FloatA,
FloatB, FloatB,
FloatC, FloatC,
......
#ifndef CK_BLOCKWISE_GEMM_V3_HPP #ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_V3_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "threadwise_gemm_v3.hpp" #include "threadwise_gemm_dlops_v3.hpp"
namespace ck { namespace ck {
...@@ -19,7 +19,7 @@ template <index_t BlockSize, ...@@ -19,7 +19,7 @@ template <index_t BlockSize,
index_t EPerThreadLoop, index_t EPerThreadLoop,
index_t ThreadGemmADataPerRead_K, index_t ThreadGemmADataPerRead_K,
index_t ThreadGemmBDataPerRead_W> index_t ThreadGemmBDataPerRead_W>
struct BlockwiseGemm_km_kn_m0m1n0n1_v3 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{ {
struct MatrixIndex struct MatrixIndex
{ {
...@@ -51,7 +51,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -51,7 +51,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
ThreadGemmADataPerRead_K, ThreadGemmADataPerRead_K,
1>; 1>;
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v3() __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)} a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
{ {
...@@ -138,9 +138,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -138,9 +138,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
static_assert(WPerThread % WoPerThreadSubC == 0, ""); static_assert(WPerThread % WoPerThreadSubC == 0, "");
// thread A buffer for GEMM // thread A buffer for GEMM
StaticBuffer<AddressSpace::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf; StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()>
a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA, constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
FloatB, FloatB,
FloatC, FloatC,
decltype(a_thread_mtx_), decltype(a_thread_mtx_),
......
#ifndef CK_BLOCKWISE_GEMM_V2_HPP
#define CK_BLOCKWISE_GEMM_V2_HPP
#include "common_header.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_gemm_v2.hpp"
namespace ck {
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block, C is distributed among the threads
// Assume:
// 1. A:
// 1. ABlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BBlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
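// Illustrative note (added for this write-up, not in the original header):
// written element-wise, the update above is
//   C[m0, m1, n0, n1] += sum_k A[k, m0, m1] * B[k, n0, n1]
// so K is the contracted dimension, which is why A enters as transpose(A);
// both block-visible buffers are laid out with K as the outermost index.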
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc,
typename BBlockDesc,
typename CThreadDesc,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
index_t AThreadCopyScalarPerVector_M1,
index_t BThreadCopyScalarPerVector_N1,
typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
BBlockDesc::IsKnownAtCompileTime() &&
CThreadDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
public:
__device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1()
: c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
CThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
"wrong! blocksize and cluster size not consistent");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
}
__device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
{
constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
// 4-d data space into 4-d thread space
// upper: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1, M1N1ThreadClusterN10 *
// M1N1ThreadClusterN11} lower: {M0, M1, N0, N1}
constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
make_tuple(make_vectorize_transform(M0, 1),
make_vectorize_transform(M1PerThread, M1 / M1PerThread),
make_vectorize_transform(N0, 1),
make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// thread position 4-d thread space
// upper: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10,
// M1N1ThreadClusterN11} lower: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1,
// M1N1ThreadClusterN10 * M1N1ThreadClusterN11}
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));
// 4-d thread space to 1-d thread space
// upper: {BlockSize}
// lower: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10,
// M1N1ThreadClusterN11}
constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11))),
make_tuple(Sequence<0, 2, 1, 3>{}),
make_tuple(Sequence<0>{}));
constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);
return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());
constexpr auto threadwise_gemm =
ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
FloatB,
FloatC,
decltype(a_thread_desc_),
decltype(b_thread_desc_),
CThreadDesc,
Sequence<KPerThread>,
Sequence<M0_, M1PerThread>,
Sequence<N0_, N1PerThread>>{};
constexpr index_t K = ABlockDesc{}.GetLength(I0);
static_for<0, K, KPerThread>{}([&](auto k) {
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
});
}
private:
static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);
// A[K, M0, M1]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));
// B[K, N0, N1]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));
using AThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<KPerThread, M0_, M1PerThread>,
Sequence<0, 1, 2>,
2,
AThreadCopyScalarPerVector_M1,
1>;
using BThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<KPerThread, N0_, N1PerThread>,
Sequence<0, 1, 2>,
2,
BThreadCopyScalarPerVector_N1,
1>;
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block, C is distributed among the threads
// Assume:
// 1. A:
// 1. ABlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BBlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
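// Illustrative summary of the 2x2 "ABBA" schedule implemented in Run() below
// (added for this write-up): the first k-slice of A_sub_0, B_sub_0, B_sub_1 and
// A_sub_1 is read up front and C_sub_00 / C_sub_01 are accumulated; inside the
// k-loop the remaining FMAs (C_sub_10, C_sub_11) are interleaved with reads that
// prefetch the next k-slice, and the final C_sub_10 / C_sub_11 updates run after
// the loop, so memory reads and math overlap.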
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc,
typename BBlockDesc,
typename CThreadDesc,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
index_t AThreadCopyScalarPerVector_M1,
index_t BThreadCopyScalarPerVector_N1,
typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
BBlockDesc::IsKnownAtCompileTime() &&
CThreadDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
public:
__device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2()
: c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
CThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
"wrong! blocksize and cluster size not consistent");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
// TODO: remove this restriction
static_assert(ABlockDesc{}.GetLength(I1) == 2 && BBlockDesc{}.GetLength(I1) == 2 &&
CThreadDesc{}.GetLength(I0) == 2 && CThreadDesc{}.GetLength(I2) == 2,
"wrong");
}
__device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
{
constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
// 4-d data space into 4-d thread space
constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
make_tuple(make_vectorize_transform(M0, 1),
make_vectorize_transform(M1PerThread, M1 / M1PerThread),
make_vectorize_transform(N0, 1),
make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// thread position 4-d thread space
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));
// 4-d thread space to 1-d thread space
constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11))),
make_tuple(Sequence<0, 2, 1, 3>{}),
make_tuple(Sequence<0>{}));
constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);
return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());
constexpr auto threadwise_gemm =
ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
FloatB,
FloatC,
decltype(a_thread_desc_),
decltype(b_thread_desc_),
CThreadDesc,
Sequence<KPerThread>,
Sequence<1, M1PerThread>,
Sequence<1, N1PerThread>>{};
constexpr index_t K = ABlockDesc{}.GetLength(I0);
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I1, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I1, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
// loop over rest of k
static_for<KPerThread, K, KPerThread>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I1, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I1, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
}
private:
static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);
// A[K, M0, M1]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));
// B[K, N0, N1]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));
using AThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<KPerThread, 1, M1PerThread>,
Sequence<0, 1, 2>,
2,
AThreadCopyScalarPerVector_M1,
1>;
using BThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<KPerThread, 1, N1PerThread>,
Sequence<0, 1, 2>,
2,
BThreadCopyScalarPerVector_N1,
1>;
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
#endif
...@@ -138,10 +138,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -138,10 +138,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
auto a_thread_buf = auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
make_static_buffer<AddressSpace::Vgpr, FloatAB>(a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
make_static_buffer<AddressSpace::Vgpr, FloatAB>(b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
...@@ -358,10 +358,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -358,10 +358,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
auto a_thread_buf = auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
make_static_buffer<AddressSpace::Vgpr, FloatAB>(a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
make_static_buffer<AddressSpace::Vgpr, FloatAB>(b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
......
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R2_HPP #ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R2_HPP #define CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r3.hpp" #include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" #include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp" #include "threadwise_dynamic_tensor_slice_set.hpp"
...@@ -25,7 +25,7 @@ __global__ void ...@@ -25,7 +25,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_contraction_v1r2( kernel_dynamic_contraction_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
...@@ -55,7 +55,7 @@ template <index_t BlockSize, ...@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGridDesc_GK0_GM0_GM1_GK1, typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1, typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1, typename CGridDesc_GM0_GM1_GN0_GN1,
...@@ -65,10 +65,8 @@ template <index_t BlockSize, ...@@ -65,10 +65,8 @@ template <index_t BlockSize,
index_t BM1PerThreadBM11, index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11, index_t BN1PerThreadBN11,
index_t BK0PerThread, index_t BK0PerThread,
index_t BM10BN10ThreadClusterBM100, typename BM10BN10ThreadClusterBM10Xs,
index_t BM10BN10ThreadClusterBN100, typename BM10BN10ThreadClusterBN10Xs,
index_t BM10BN10ThreadClusterBM101,
index_t BM10BN10ThreadClusterBN101,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
...@@ -91,7 +89,7 @@ template <index_t BlockSize, ...@@ -91,7 +89,7 @@ template <index_t BlockSize,
typename CGridIteratorHacks, typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -252,9 +250,11 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ ...@@ -252,9 +250,11 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
constexpr auto BN = GN0 * GN11; constexpr auto BN = GN0 * GN11;
constexpr auto BM1 = constexpr auto BM1 =
Number<BM10BN10ThreadClusterBM100 * BM10BN10ThreadClusterBM101 * BM1PerThreadBM11>{}; Number<container_reduce(BM10BN10ThreadClusterBM10Xs{}, math::multiplies_v2{}, I1) *
BM1PerThreadBM11>{};
constexpr auto BN1 = constexpr auto BN1 =
Number<BM10BN10ThreadClusterBN100 * BM10BN10ThreadClusterBN101 * BN1PerThreadBN11>{}; Number<container_reduce(BM10BN10ThreadClusterBN10Xs{}, math::multiplies_v2{}, I1) *
BN1PerThreadBN11>{};
constexpr auto BM0 = BM / BM1; constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1; constexpr auto BN0 = BN / BN1;
...@@ -331,11 +331,11 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ ...@@ -331,11 +331,11 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize()); p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());
const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0); const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
...@@ -387,7 +387,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ ...@@ -387,7 +387,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize, BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>, Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
...@@ -411,7 +411,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ ...@@ -411,7 +411,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize, BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>, Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
...@@ -439,7 +439,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ ...@@ -439,7 +439,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in // c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
// register // register
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
...@@ -449,10 +449,8 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ ...@@ -449,10 +449,8 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
BM1PerThreadBM11, BM1PerThreadBM11,
BN1PerThreadBN11, BN1PerThreadBN11,
BK0PerThread, BK0PerThread,
BM10BN10ThreadClusterBM100, BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN100, BM10BN10ThreadClusterBN10Xs,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
BM1PerThreadBM11, BM1PerThreadBM11,
BN1PerThreadBN11>{}; BN1PerThreadBN11>{};
...@@ -474,7 +472,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ ...@@ -474,7 +472,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output // register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc, ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
...@@ -488,15 +486,15 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ ...@@ -488,15 +486,15 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>( auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>( auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>( auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double + a_block_aligned_space_size, p_a_block_double + a_block_aligned_space_size,
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>( auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double + b_block_aligned_space_size, p_b_block_double + b_block_aligned_space_size,
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
......
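// The blockwise GEMM above now takes its BM10/BN10 thread-cluster sizes as a single compile-time
// sequence (BM10BN10ThreadClusterBM10Xs / ...BN10Xs) instead of two separate scalars, and later
// hunks recover the sub-tile size as the product of that sequence times the per-thread length.
// A minimal standalone sketch of that pattern, using std::integer_sequence and a fold expression
// in place of ck::Sequence and container_reduce (all names below are illustrative, not the ck API):

#include <utility>

template <int... Xs>
constexpr int product_of(std::integer_sequence<int, Xs...>)
{
    return (1 * ... * Xs); // compile-time product of all cluster lengths
}

// e.g. a 2x4 thread cluster along BM10 and 4 elements per thread along BM11:
using BM10ClusterXs            = std::integer_sequence<int, 2, 4>;
constexpr int BM1PerThreadBM11 = 4;
constexpr int BM11             = product_of(BM10ClusterXs{}) * BM1PerThreadBM11; // = 32

static_assert(BM11 == 32, "cluster product times per-thread length gives the BM11 sub-tile size");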
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R2_HPP #ifndef CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_V1R2_HPP #define CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r2.hpp" #include "blockwise_gemm_dlops_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp" #include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp" #include "threadwise_dynamic_tensor_slice_set.hpp"
...@@ -26,7 +26,7 @@ __global__ void ...@@ -26,7 +26,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_v1r2( kernel_dynamic_gemm_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
...@@ -52,8 +52,8 @@ __global__ void ...@@ -52,8 +52,8 @@ __global__ void
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer // pass tensor descriptor by CONSTANT void pointer
// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature point to // CONSTANT is needed to inform the compiler that the void pointers in the kernel signature point to
// the non-modifiable parameter address space, so the compiler can enable the corresponding optimizations // the non-modifiable parameter address space, so the compiler can enable the corresponding optimizations
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
...@@ -68,16 +68,16 @@ __global__ void ...@@ -68,16 +68,16 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_v1r2( kernel_dynamic_gemm_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k_m0_m1_grid_desc, const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void __CONSTANT__* p_b_k_n0_n1_grid_desc, const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void __CONSTANT__* p_c_m0_m10_m11_n0_n10_n11_grid_desc, const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor) const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{ {
// first cast the __CONSTANT__ void* to a generic void* // first cast the CONSTANT void* to a generic void*
// second cast the void* to Desc* // second cast the void* to Desc*
// because the copy constructor of the tensor descriptor doesn't take an address_space(4) pointer // because the copy constructor of the tensor descriptor doesn't take an address_space(4) pointer
const auto a_k_m0_m1_grid_desc = const auto a_k_m0_m1_grid_desc =
...@@ -113,7 +113,7 @@ template <index_t BlockSize, ...@@ -113,7 +113,7 @@ template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AKMGridDesc, typename AKMGridDesc,
typename BKNGridDesc, typename BKNGridDesc,
typename CMNGridDesc, typename CMNGridDesc,
...@@ -151,7 +151,7 @@ template <index_t BlockSize, ...@@ -151,7 +151,7 @@ template <index_t BlockSize,
typename CGridIteratorHacks, typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v1r2 struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -326,11 +326,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2 ...@@ -326,11 +326,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize()); p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize()); p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
const auto K = a_k_m0_m1_grid_desc.GetLength(I0); const auto K = a_k_m0_m1_grid_desc.GetLength(I0);
...@@ -373,7 +373,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2 ...@@ -373,7 +373,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1>, Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1, ABlockTransferThreadClusterLengths_K_M0_M1,
...@@ -399,7 +399,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2 ...@@ -399,7 +399,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1>, Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1, BBlockTransferThreadClusterLengths_K_N0_N1,
...@@ -429,7 +429,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2 ...@@ -429,7 +429,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// register // register
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize, BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
...@@ -462,7 +462,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2 ...@@ -462,7 +462,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output // register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc, ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
...@@ -487,16 +487,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2 ...@@ -487,16 +487,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack = constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{}; BGridMoveSliceWindowIteratorHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>( auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>( auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize()); p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double + a_block_aligned_space_size, p_a_block_double + a_block_aligned_space_size,
a_k_m0_m1_block_desc.GetElementSpaceSize()); a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double + b_block_aligned_space_size, p_b_block_double + b_block_aligned_space_size,
b_k_n0_n1_block_desc.GetElementSpaceSize()); b_k_n0_n1_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
......
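// The LDS double-buffer layout above places A's even buffer at p_a_block_double, A's odd buffer
// at p_a_block_double + a_block_aligned_space_size, and B's two buffers starting at
// p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size. A host-side sketch of that
// layout and of the even/odd ping-pong over the K loop (the sizes and the load_tile /
// process_tile helpers are illustrative stand-ins, not the ck API):

#include <vector>

int main()
{
    constexpr int a_size = 256, b_size = 512, num_k_blocks = 8;
    std::vector<float> shared_block(2 * a_size + 2 * b_size);

    float* a_even = shared_block.data();              // p_a_block_double
    float* a_odd  = a_even + a_size;                  // + a_block_aligned_space_size
    float* b_even = shared_block.data() + 2 * a_size; // p_b_block_double
    float* b_odd  = b_even + b_size;                  // + b_block_aligned_space_size

    auto load_tile    = [](float* dst, int n, int k) { for(int i = 0; i < n; ++i) dst[i] = float(k); };
    auto process_tile = [](const float*, const float*) { /* blockwise GEMM would run here */ };

    // preload K block 0 into the even buffers
    load_tile(a_even, a_size, 0);
    load_tile(b_even, b_size, 0);

    for(int k = 1; k < num_k_blocks; ++k)
    {
        float* a_next = (k % 2) ? a_odd : a_even; // K block k is written into the idle buffers
        float* b_next = (k % 2) ? b_odd : b_even;
        float* a_cur  = (k % 2) ? a_even : a_odd; // while K block k-1 is consumed from the others
        float* b_cur  = (k % 2) ? b_even : b_odd;

        load_tile(a_next, a_size, k);
        load_tile(b_next, b_size, k);
        process_tile(a_cur, b_cur);
    }

    // tail: consume the last K block that was loaded
    process_tile((num_k_blocks % 2) ? a_even : a_odd, (num_k_blocks % 2) ? b_even : b_odd);
}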
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "dynamic_multi_index_transform_helper.hpp" #include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r3.hpp" #include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" #include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp" #include "threadwise_dynamic_tensor_slice_set.hpp"
...@@ -26,7 +26,7 @@ __global__ void ...@@ -26,7 +26,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_v1r3( kernel_dynamic_gemm_dlops_v1r3(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
...@@ -52,8 +52,8 @@ __global__ void ...@@ -52,8 +52,8 @@ __global__ void
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer // pass tensor descriptor by CONSTANT void pointer
// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature point to // CONSTANT is needed to inform the compiler that the void pointers in the kernel signature point to
// the non-modifiable parameter address space, so the compiler can enable the corresponding optimizations // the non-modifiable parameter address space, so the compiler can enable the corresponding optimizations
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
...@@ -68,16 +68,16 @@ __global__ void ...@@ -68,16 +68,16 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_v1r3( kernel_dynamic_gemm_dlops_v1r3(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k0_m0_m1_k1_grid_desc, const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
const void __CONSTANT__* p_b_k0_n0_n1_k1_grid_desc, const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc,
const void __CONSTANT__* p_c_m0_m10_m11_n0_n10_n11_grid_desc, const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor) const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{ {
// first cast the __CONSTANT__ void* to a generic void* // first cast the CONSTANT void* to a generic void*
// second cast the void* to Desc* // second cast the void* to Desc*
// because the copy constructor of the tensor descriptor doesn't take an address_space(4) pointer // because the copy constructor of the tensor descriptor doesn't take an address_space(4) pointer
const auto a_k0_m0_m1_k1_grid_desc = const auto a_k0_m0_m1_k1_grid_desc =
...@@ -113,7 +113,7 @@ template <index_t BlockSize, ...@@ -113,7 +113,7 @@ template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AK0MK1GridDesc, typename AK0MK1GridDesc,
typename BK0NK1GridDesc, typename BK0NK1GridDesc,
typename CMNGridDesc, typename CMNGridDesc,
...@@ -123,10 +123,8 @@ template <index_t BlockSize, ...@@ -123,10 +123,8 @@ template <index_t BlockSize,
index_t M1PerThreadM111, index_t M1PerThreadM111,
index_t N1PerThreadN111, index_t N1PerThreadN111,
index_t KPerThread, index_t KPerThread,
index_t M11N11ThreadClusterM1100, typename M11N11ThreadClusterM110Xs,
index_t M11N11ThreadClusterN1100, typename M11N11ThreadClusterN110Xs,
index_t M11N11ThreadClusterM1101,
index_t M11N11ThreadClusterN1101,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
...@@ -149,7 +147,7 @@ template <index_t BlockSize, ...@@ -149,7 +147,7 @@ template <index_t BlockSize,
typename CGridIteratorHacks, typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v1r3 struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -277,9 +275,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 ...@@ -277,9 +275,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
const auto N0 = N / N1; const auto N0 = N / N1;
constexpr auto M11 = constexpr auto M11 =
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{}; Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies_v2{}, I1) *
M1PerThreadM111>{};
constexpr auto N11 = constexpr auto N11 =
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{}; Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies_v2{}, I1) *
N1PerThreadN111>{};
constexpr auto M10 = M1 / M11; constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11; constexpr auto N10 = N1 / N11;
...@@ -333,11 +333,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 ...@@ -333,11 +333,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
...@@ -383,7 +383,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 ...@@ -383,7 +383,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize, BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>, Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
...@@ -407,7 +407,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 ...@@ -407,7 +407,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize, BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>, Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
...@@ -435,7 +435,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 ...@@ -435,7 +435,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// register // register
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
...@@ -445,15 +445,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 ...@@ -445,15 +445,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
M1PerThreadM111, M1PerThreadM111,
N1PerThreadN111, N1PerThreadN111,
KPerThread, KPerThread,
M11N11ThreadClusterM1100, M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN1100, M11N11ThreadClusterN110Xs,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111, M1PerThreadM111,
N1PerThreadN111>{}; N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_m10_m11_n10_n11_thread_desc = constexpr auto c_m10_m11_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2( make_dynamic_naive_tensor_descriptor_packed_v2(
...@@ -470,7 +468,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 ...@@ -470,7 +468,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output // register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc, ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
...@@ -484,16 +482,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 ...@@ -484,16 +482,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>( auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>( auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double + a_block_aligned_space_size, p_a_block_double + a_block_aligned_space_size,
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double + b_block_aligned_space_size, p_b_block_double + b_block_aligned_space_size,
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
...@@ -610,10 +608,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 ...@@ -610,10 +608,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
// output: register to global memory // output: register to global memory
{ {
constexpr index_t M11 = constexpr auto M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101; Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies_v2{}, I1) *
constexpr index_t N11 = M1PerThreadM111>{};
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101; constexpr auto N11 =
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies_v2{}, I1) *
N1PerThreadN111>{};
constexpr index_t M10 = MPerBlockM1 / M11; constexpr index_t M10 = MPerBlockM1 / M11;
constexpr index_t N10 = NPerBlockN1 / N11; constexpr index_t N10 = NPerBlockN1 / N11;
...@@ -631,7 +631,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 ...@@ -631,7 +631,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{})); Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc, FloatAcc,
......
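// The kernels above receive each tensor descriptor as a 'const void CONSTANT*' (the
// non-modifiable parameter address space, address_space(4) per the comments) and recover the
// typed descriptor with the two-step cast described there: drop the address-space qualifier to
// get a generic void*, then cast to the descriptor type and copy it by value. A host-compilable
// sketch of that two-step cast, with CONSTANT defined away and DescType / kernel_like standing
// in for a real grid descriptor and kernel (both are assumptions for illustration only):

#include <cstdio>

#ifndef CONSTANT
#define CONSTANT // on the HIP side this would carry the address_space(4) qualifier
#endif

struct DescType
{
    int lengths[4];
};

void kernel_like(const void CONSTANT* p_desc_void)
{
    // step 1: strip the (here empty) qualifier down to a generic const void*
    const void* p_generic = (const void*)p_desc_void;
    // step 2: reinterpret as the concrete descriptor type and copy it by value
    const DescType desc = *reinterpret_cast<const DescType*>(p_generic);
    std::printf("%d\n", desc.lengths[0]);
}

int main()
{
    DescType d{{8, 4, 2, 1}};
    kernel_like(&d);
}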
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp" #include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "blockwise_gemm_v3.hpp" #include "blockwise_gemm_dlops_v3.hpp"
namespace ck { namespace ck {
...@@ -15,7 +15,7 @@ template <index_t BlockSize, ...@@ -15,7 +15,7 @@ template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGlobalDesc, typename AGlobalDesc,
typename BGlobalDesc, typename BGlobalDesc,
typename CGlobalDesc, typename CGlobalDesc,
...@@ -47,7 +47,7 @@ template <index_t BlockSize, ...@@ -47,7 +47,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks, typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks, typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks> typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v3 struct GridwiseDynamicGemmDlops_km_kn_mn_v3
{ {
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
...@@ -84,11 +84,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -84,11 +84,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e_k_global_desc.GetElementSpaceSize()); p_a_global, a_e_k_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize()); p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>( auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize()); p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
constexpr auto E = EPerBlock * 3 * 3; constexpr auto E = EPerBlock * 3 * 3;
...@@ -100,7 +100,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -100,7 +100,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2); const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3); const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
// divide block work by [M, N] // divide block work by [M, N]
#if 0 #if 0
const auto k_block_work_num = K / Number<KPerBlock>{}; const auto k_block_work_num = K / Number<KPerBlock>{};
const auto ho_block_work_num = Ho / Number<HoPerBlock>{}; const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
...@@ -152,7 +152,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -152,7 +152,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize, auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
...@@ -184,7 +185,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -184,7 +185,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
Sequence<E, KPerBlock>, Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K, ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K, ABlockTransferThreadClusterLengths_E_K,
...@@ -225,11 +226,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -225,11 +226,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
true>(b_e_n_ho_wo_global_desc, true>(b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(p_shared_block, auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
a_e_k_desc.GetElementSpaceSize()); p_shared_block, a_e_k_desc.GetElementSpaceSize());
// register allocation for output // register allocation for output
StaticBuffer<AddressSpace::Vgpr, FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
c_thread_buf; c_thread_buf;
// initialize output thread tensor // initialize output thread tensor
...@@ -252,7 +255,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -252,7 +255,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
BGlobalMoveSliceWindowIteratorHacks{}; BGlobalMoveSliceWindowIteratorHacks{};
// double register buffer for b // double register buffer for b
StaticBuffer<AddressSpace::Vgpr, FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
b_thread_even_buf, b_thread_odd_buf; b_thread_even_buf, b_thread_odd_buf;
// LDS double buffer: preload data // LDS double buffer: preload data
......
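// As the descriptor names above suggest, the v3 path contracts an A tile indexed [e, k] with a
// B tile indexed [e, n, ho, wo] over the reduction dimension E (EPerBlock * 3 * 3 for the 3x3
// window) into a C tile indexed [k, n, ho, wo]. A scalar reference of that contraction, with
// small illustrative tile sizes (the helper name and sizes are made up for illustration):

template <int E, int K, int N, int Ho, int Wo>
void contraction_e_k_with_e_n_ho_wo(const float (&a)[E][K],
                                    const float (&b)[E][N][Ho][Wo],
                                    float (&c)[K][N][Ho][Wo])
{
    for(int e = 0; e < E; ++e)
        for(int k = 0; k < K; ++k)
            for(int n = 0; n < N; ++n)
                for(int ho = 0; ho < Ho; ++ho)
                    for(int wo = 0; wo < Wo; ++wo)
                        c[k][n][ho][wo] += a[e][k] * b[e][n][ho][wo];
}

int main()
{
    float a[2][2]       = {{1.f, 0.f}, {0.f, 1.f}};
    float b[2][1][2][2] = {};
    float c[2][1][2][2] = {};
    b[0][0][0][0]       = 3.f;
    contraction_e_k_with_e_n_ho_wo(a, b, c); // c[0][0][0][0] == 3.f
}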
...@@ -61,10 +61,10 @@ __global__ void ...@@ -61,10 +61,10 @@ __global__ void
kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void __CONSTANT__* p_b_k0_n_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc, const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void __CONSTANT__* p_c_block_cluster_adaptor) const void CONSTANT* p_c_block_cluster_adaptor)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -95,7 +95,7 @@ template <index_t BlockSize, ...@@ -95,7 +95,7 @@ template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AK0MK1GridDesc, typename AK0MK1GridDesc,
typename BK0NK1GridDesc, typename BK0NK1GridDesc,
typename CMNGridDesc, typename CMNGridDesc,
...@@ -274,11 +274,11 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -274,11 +274,11 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
const auto a_grid_buf = make_dynamic_buffer<AddressSpace::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpace::Global>( const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize()); p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
...@@ -312,7 +312,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -312,7 +312,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, MPerBlock, K1>, Sequence<KPerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
...@@ -339,7 +339,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -339,7 +339,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, NPerBlock, K1>, Sequence<KPerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
...@@ -413,7 +413,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -413,7 +413,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto c_mr_nr_blk_desc = make_dynamic_naive_tensor_descriptor_packed_v2( constexpr auto c_mr_nr_blk_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
StaticBuffer<AddressSpace::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, BlkSize>, vector_type<FloatAcc, BlkSize>,
c_mr_nr_blk_desc.GetElementSpaceSize()> c_mr_nr_blk_desc.GetElementSpaceSize()>
c_thread_buf; c_thread_buf;
...@@ -442,9 +442,9 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -442,9 +442,9 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack = constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{}; BGridMoveSliceWindowIteratorHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpace::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
// preload data into LDS // preload data into LDS
...@@ -515,7 +515,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -515,7 +515,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
Number<M2>{}, Number<M2>{},
Number<1>{})); Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()> StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()>
c_blk_buf_; c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) { static_for<0, MRepeat, 1>{}([&](auto mr_i) {
...@@ -585,7 +585,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -585,7 +585,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{})); I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_; StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, BlkSize> c_blk_buf_;
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
......
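// Throughout this diff, buffers carry their address space (Global / Lds / Vgpr above) as an
// AddressSpaceEnum_t non-type template parameter, which is what the later
// 'SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global' static_asserts test against.
// A stripped-down sketch of that tagging idea (Buffer and this make_dynamic_buffer are
// simplified stand-ins, not the ck implementations):

#include <cstddef>

enum class AddressSpaceEnum_t { Global, Lds, Vgpr };

template <AddressSpaceEnum_t AS, typename T>
struct Buffer
{
    T* p_data;
    std::size_t element_space_size;
    static constexpr AddressSpaceEnum_t GetAddressSpace() { return AS; }
};

template <AddressSpaceEnum_t AS, typename T>
Buffer<AS, T> make_dynamic_buffer(T* p, std::size_t n)
{
    return {p, n};
}

int main()
{
    float lds_storage[256];
    auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(lds_storage, 256);
    static_assert(decltype(a_block_buf)::GetAddressSpace() == AddressSpaceEnum_t::Lds,
                  "the address space travels with the buffer type");
}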
#ifndef CK_THREADWISE_CONTRACTION_HPP #ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP
#define CK_THREADWISE_CONTRACTION_HPP #define CK_THREADWISE_CONTRACTION_DLOPS_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "math.hpp" #include "math.hpp"
...@@ -25,9 +25,9 @@ template <typename FloatA, ...@@ -25,9 +25,9 @@ template <typename FloatA,
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
{ {
__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1() __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
{ {
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
...@@ -71,8 +71,6 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 ...@@ -71,8 +71,6 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto TK = TKLengths{}[I0]; constexpr auto TK = TKLengths{}[I0];
constexpr auto TM0 = TMLengths{}[I0]; constexpr auto TM0 = TMLengths{}[I0];
...@@ -131,9 +129,9 @@ template <typename FloatA, ...@@ -131,9 +129,9 @@ template <typename FloatA,
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{ {
__device__ constexpr ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{ {
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
...@@ -177,8 +175,6 @@ struct ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_T ...@@ -177,8 +175,6 @@ struct ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_T
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr index_t TK0 = TKLengths{}[I0]; constexpr index_t TK0 = TKLengths{}[I0];
constexpr index_t TK1 = TKLengths{}[I1]; constexpr index_t TK1 = TKLengths{}[I1];
......
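// The contraction structs above compute, per thread,
//   C[m0][m1][n0][n1] += sum over (k0, k1) of A[k0][m0][m1][k1] * B[k0][n0][n1][k1],
// i.e. a small GEMM whose reduction dimension is split into an outer TK0 and an inner TK1
// component (the TKLengths read out above). A scalar reference with illustrative tile sizes:

template <int TK0, int TK1, int TM0, int TM1, int TN0, int TN1>
void threadwise_contraction_ref(const float (&a)[TK0][TM0][TM1][TK1],
                                const float (&b)[TK0][TN0][TN1][TK1],
                                float (&c)[TM0][TM1][TN0][TN1])
{
    for(int k0 = 0; k0 < TK0; ++k0)
        for(int m0 = 0; m0 < TM0; ++m0)
            for(int m1 = 0; m1 < TM1; ++m1)
                for(int n0 = 0; n0 < TN0; ++n0)
                    for(int n1 = 0; n1 < TN1; ++n1)
                        for(int k1 = 0; k1 < TK1; ++k1) // innermost: the TK1 component
                            c[m0][m1][n0][n1] += a[k0][m0][m1][k1] * b[k0][n0][n1][k1];
}

int main()
{
    float a[2][2][2][2] = {}, b[2][2][2][2] = {}, c[2][2][2][2] = {};
    a[0][0][0][0] = 1.f;
    b[0][0][0][0] = 2.f;
    threadwise_contraction_ref(a, b, c); // c[0][0][0][0] == 2.f
}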
...@@ -54,7 +54,7 @@ template <typename SrcData, ...@@ -54,7 +54,7 @@ template <typename SrcData,
typename DimAccessOrder, typename DimAccessOrder,
index_t DstVectorDim, index_t DstVectorDim,
index_t DstScalarPerVector, index_t DstScalarPerVector,
InMemoryDataOperation DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
index_t DstScalarStrideInVector, index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun, bool DstResetCoordinateAfterRun,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false> typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
...@@ -159,9 +159,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -159,9 +159,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) { static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward // judge move forward or move backward
constexpr auto forward_sweep = [&]() { constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep; StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep(I0) = true; forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0]; index_t tmp = ordered_access_idx[I0];
...@@ -170,10 +170,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -170,10 +170,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
}); });
forward_sweep(i) = tmp % 2 == 0; forward_sweep_(i) = tmp % 2 == 0;
}); });
return forward_sweep; return forward_sweep_;
}(); }();
// calculate dst data index // calculate dst data index
...@@ -186,10 +186,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -186,10 +186,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
: ordered_access_lengths[i] - 1 - ordered_access_idx[i]; : ordered_access_lengths[i] - 1 - ordered_access_idx[i];
}); });
auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) * return container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access; dst_scalar_per_access;
return dst_data_idx;
}(); }();
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector; typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
...@@ -217,17 +215,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -217,17 +215,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
StaticallyIndexedArray<bool, nDim> move_on_dim; StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) { static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) { static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
}); });
}); });
return move_on_dim; return move_on_dim_;
} }
(); ();
...@@ -295,9 +293,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -295,9 +293,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
// judge move forward or move backward during the last iteration // judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() { constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep; StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep(I0) = true; forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1; index_t tmp = ordered_access_lengths[I0] - 1;
...@@ -306,10 +304,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -306,10 +304,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
}); });
forward_sweep(i) = tmp % 2 == 0; forward_sweep_(i) = tmp % 2 == 0;
}); });
return forward_sweep; return forward_sweep_;
}(); }();
// calculate dst data index after last iteration in Run(), if it has not been reset by // calculate dst data index after last iteration in Run(), if it has not been reset by
...@@ -321,19 +319,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -321,19 +319,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
}); });
auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) * return container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access; dst_scalar_per_access;
return dst_data_idx;
}(); }();
// //
constexpr auto reset_dst_data_step = [&]() { constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step; Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; }); static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step; return reset_dst_data_step_;
}(); }();
return reset_dst_data_step; return reset_dst_data_step;
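// The forward_sweep lambdas above (renamed to forward_sweep_ internally to avoid shadowing the
// enclosing name) implement a zig-zag traversal: dimension i sweeps forward or backward
// depending on the parity of the mixed-radix number formed by the access indices of the
// dimensions ordered before it, so the coordinate is updated by a short move between accesses
// instead of being recomputed. A standalone sketch of the same parity rule, for runtime indices
// (nDim, the index values, and the lengths below are illustrative):

#include <cstdio>
#include <vector>

std::vector<bool> forward_sweep_of(const std::vector<int>& ordered_access_idx,
                                   const std::vector<int>& ordered_access_lengths)
{
    const int nDim = static_cast<int>(ordered_access_idx.size());
    std::vector<bool> forward_sweep(nDim, true); // dimension 0 always sweeps forward
    for(int i = 1; i < nDim; ++i)
    {
        int tmp = ordered_access_idx[0];
        for(int j = 1; j < i; ++j)
            tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
        forward_sweep[i] = (tmp % 2 == 0); // even prefix -> forward, odd prefix -> backward
    }
    return forward_sweep;
}

int main()
{
    const auto fs = forward_sweep_of({1, 0, 2}, {2, 3, 4});
    for(bool f : fs)
        std::printf("%d ", int(f)); // prints "1 0 0": dimensions 1 and 2 sweep backward here
}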
...@@ -478,9 +474,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -478,9 +474,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) { static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward // judge move forward or move backward
constexpr auto forward_sweep = [&]() { constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep; StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep(I0) = true; forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0]; index_t tmp = ordered_access_idx[I0];
...@@ -489,10 +485,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -489,10 +485,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
}); });
forward_sweep(i) = tmp % 2 == 0; forward_sweep_(i) = tmp % 2 == 0;
}); });
return forward_sweep; return forward_sweep_;
}(); }();
// calculate src data index // calculate src data index
...@@ -505,10 +501,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -505,10 +501,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
: ordered_access_lengths[i] - 1 - ordered_access_idx[i]; : ordered_access_lengths[i] - 1 - ordered_access_idx[i];
}); });
auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) * return container_reorder_given_old2new(ordered_idx, dim_access_order) *
src_scalar_per_access; src_scalar_per_access;
return src_data_idx;
}(); }();
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector; typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
...@@ -534,17 +528,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -534,17 +528,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
StaticallyIndexedArray<bool, nDim> move_on_dim; StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) { static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) { static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
}); });
}); });
return move_on_dim; return move_on_dim_;
} }
(); ();
...@@ -612,9 +606,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -612,9 +606,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
// judge move forward or move backward during the last iteration // judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() { constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep; StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep(I0) = true; forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1; index_t tmp = ordered_access_lengths[I0] - 1;
...@@ -623,10 +617,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -623,10 +617,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
}); });
forward_sweep(i) = tmp % 2 == 0; forward_sweep_(i) = tmp % 2 == 0;
}); });
return forward_sweep; return forward_sweep_;
}(); }();
// calculate src data index after last iteration in Run(), if it has not been reset by // calculate src data index after last iteration in Run(), if it has not been reset by
...@@ -638,19 +632,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -638,19 +632,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
}); });
auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) * return container_reorder_given_old2new(ordered_idx, dim_access_order) *
src_scalar_per_access; src_scalar_per_access;
return src_data_idx;
}(); }();
// //
constexpr auto reset_src_data_step = [&]() { constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step; Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; }); static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
return reset_src_data_step; return reset_src_data_step_;
}(); }();
return reset_src_data_step; return reset_src_data_step;
...@@ -682,7 +674,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -682,7 +674,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
// 3. src_slice_origin and dst_slice_origin are not known at compile-time, // 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer // 4. Use thread buffer
template <typename SliceLengths, template <typename SliceLengths,
InMemoryDataOperation DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename SrcData, typename SrcData,
typename DstData, typename DstData,
typename SrcDesc, typename SrcDesc,
...@@ -739,8 +731,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -739,8 +731,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
const SrcIteratorHacks& src_iterator_hacks) const SrcIteratorHacks& src_iterator_hacks)
{ {
static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
SrcBuffer::GetAddressSpace() == AddressSpace::Lds, SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!"); "wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>, static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
...@@ -797,9 +789,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -797,9 +789,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) { static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward // judge move forward or move backward
constexpr auto forward_sweep = [&]() { constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep; StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep(I0) = true; forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0]; index_t tmp = ordered_src_access_idx[I0];
...@@ -808,10 +800,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -808,10 +800,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
}); });
forward_sweep(i) = tmp % 2 == 0; forward_sweep_(i) = tmp % 2 == 0;
}); });
return forward_sweep; return forward_sweep_;
}(); }();
// calculate src data index // calculate src data index
...@@ -824,11 +816,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -824,11 +816,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
ordered_src_access_idx[i]; ordered_src_access_idx[i];
}); });
auto src_data_idx = return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access; src_scalar_per_access;
return src_data_idx;
}(); }();
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector; vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
...@@ -852,18 +841,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -852,18 +841,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
StaticallyIndexedArray<bool, nDim> move_on_dim; StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) { static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) { static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &= move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
}); });
}); });
return move_on_dim; return move_on_dim_;
} }
(); ();
...@@ -900,8 +889,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -900,8 +889,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
DstBuffer& dst_buf, DstBuffer& dst_buf,
const DstIteratorHacks& dst_iterator_hacks) const DstIteratorHacks& dst_iterator_hacks)
{ {
static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
DstBuffer::GetAddressSpace() == AddressSpace::Lds, DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!"); "wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>, static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
...@@ -962,9 +951,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -962,9 +951,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) { static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward // judge move forward or move backward
constexpr auto forward_sweep = [&]() { constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep; StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep(I0) = true; forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0]; index_t tmp = ordered_dst_access_idx[I0];
...@@ -973,10 +962,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -973,10 +962,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
}); });
forward_sweep(i) = tmp % 2 == 0; forward_sweep_(i) = tmp % 2 == 0;
}); });
return forward_sweep; return forward_sweep_;
}(); }();
// calculate dst data index // calculate dst data index
...@@ -989,11 +978,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -989,11 +978,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
ordered_dst_access_idx[i]; ordered_dst_access_idx[i];
}); });
auto dst_data_idx = return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access; dst_scalar_per_access;
return dst_data_idx;
}(); }();
vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector; vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;
...@@ -1019,18 +1005,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1019,18 +1005,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
StaticallyIndexedArray<bool, nDim> move_on_dim; StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) { static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) { static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &= move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
}); });
}); });
return move_on_dim; return move_on_dim_;
} }
(); ();
...@@ -1108,9 +1094,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1108,9 +1094,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// judge move forward or move backward during the last iteration // judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() { constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep; StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep(I0) = true; forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1; index_t tmp = ordered_src_access_lengths[I0] - 1;
...@@ -1119,10 +1105,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1119,10 +1105,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
}); });
forward_sweep(i) = tmp % 2 == 0; forward_sweep_(i) = tmp % 2 == 0;
}); });
return forward_sweep; return forward_sweep_;
}(); }();
// calculate src data index after last iteration in RunRead(), if it has not been reset by // calculate src data index after last iteration in RunRead(), if it has not been reset by
...@@ -1134,19 +1120,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1134,19 +1120,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
}); });
auto src_data_idx = container_reorder_given_old2new(ordered_idx, src_dim_access_order) * return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access; src_scalar_per_access;
return src_data_idx;
}(); }();
// //
constexpr auto reset_src_data_step = [&]() { constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step; Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; }); static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
return reset_src_data_step; return reset_src_data_step_;
}(); }();
return reset_src_data_step; return reset_src_data_step;
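The reset step computed above is the elementwise negation of the data index reached after the final access, so adding the step to that index gives zero in every dimension and the source coordinate returns to the slice-window origin. A tiny sketch of that arithmetic with made-up index values:

#include <array>
#include <cstdio>

int main()
{
    constexpr int nDim = 3;
    const std::array<int, nDim> src_data_idx{0, 6, 3}; // index after the last access (example values)

    std::array<int, nDim> reset_step{};
    for(int i = 0; i < nDim; ++i)
        reset_step[i] = -src_data_idx[i];

    for(int i = 0; i < nDim; ++i)
        std::printf("%d ", src_data_idx[i] + reset_step[i]); // prints: 0 0 0
    std::printf("\n");
    return 0;
}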
...@@ -1170,9 +1154,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1170,9 +1154,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// judge move forward or move backward during the last iteration // judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() { constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep; StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep(I0) = true; forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1; index_t tmp = ordered_dst_access_lengths[I0] - 1;
...@@ -1181,10 +1165,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1181,10 +1165,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
}); });
forward_sweep(i) = tmp % 2 == 0; forward_sweep_(i) = tmp % 2 == 0;
}); });
return forward_sweep; return forward_sweep_;
}(); }();
// calculate dst data index after last iteration in RunWrite(), if it has not been reset by // calculate dst data index after last iteration in RunWrite(), if it has not been reset by
...@@ -1196,19 +1180,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1196,19 +1180,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
}); });
auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access; dst_scalar_per_access;
return dst_data_idx;
}(); }();
// //
constexpr auto reset_dst_data_step = [&]() { constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step; Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; }); static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step; return reset_dst_data_step_;
}(); }();
return reset_dst_data_step; return reset_dst_data_step;
...@@ -1270,7 +1252,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1270,7 +1252,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticBuffer<AddressSpace::Vgpr, SrcData, buffer_size_> buffer_; StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
SrcCoord src_coord_; SrcCoord src_coord_;
DstCoord dst_coord_; DstCoord dst_coord_;
...@@ -1357,9 +1339,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4 ...@@ -1357,9 +1339,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// scalar per access of each dim // scalar per access of each dim
constexpr auto src_scalar_per_access = generate_sequence_v2( constexpr auto src_scalar_per_access = generate_sequence_v2(
[&](auto i) constexpr { [&](auto i) constexpr {
......
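Beyond the renames inside the lambdas, the diff consistently replaces AddressSpace with AddressSpaceEnum_t, and (in the v3r1 template parameters just below) InMemoryDataOperation with InMemoryDataOperationEnum_t. A hedged sketch of what the suffixed declarations presumably look like; only Global, Lds and Vgpr are visible in this diff, and the remaining enumerators are assumptions added for illustration:

enum struct AddressSpaceEnum_t
{
    Generic, // assumption
    Global,
    Lds,
    Sgpr,    // assumption
    Vgpr
};

enum struct InMemoryDataOperationEnum_t
{
    Set,      // assumption
    AtomicAdd // assumption
};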
...@@ -13,7 +13,7 @@ namespace ck { ...@@ -13,7 +13,7 @@ namespace ck {
// 3. src_slice_origin and dst_slice_origin are not known at compile-time, // 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer // 4. Use thread buffer
template <typename SliceLengths, template <typename SliceLengths,
InMemoryDataOperation DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename SrcData, typename SrcData,
typename DstData, typename DstData,
typename SrcDesc, typename SrcDesc,
...@@ -77,8 +77,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -77,8 +77,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
const SrcIteratorHacks& src_iterator_hacks) const SrcIteratorHacks& src_iterator_hacks)
{ {
static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
SrcBuffer::GetAddressSpace() == AddressSpace::Lds, SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!"); "wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>, static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
...@@ -140,9 +140,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -140,9 +140,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) { static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward // judge move forward or move backward
constexpr auto forward_sweep = [&]() { constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep; StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep(I0) = true; forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0]; index_t tmp = ordered_src_access_idx[I0];
...@@ -151,10 +151,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -151,10 +151,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
}); });
forward_sweep(i) = tmp % 2 == 0; forward_sweep_(i) = tmp % 2 == 0;
}); });
return forward_sweep; return forward_sweep_;
}(); }();
// calculate src data index // calculate src data index
...@@ -167,11 +167,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -167,11 +167,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
ordered_src_access_idx[i]; ordered_src_access_idx[i];
}); });
auto src_data_idx = return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_vector_tensor_lengths; src_vector_tensor_lengths;
return src_data_idx;
}(); }();
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector; vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
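Unlike v3, which scales the reordered access index by its scalar-per-access vector, v3r1 scales it by per-dimension vector tensor lengths, and src_vector above holds src_vector_desc.GetElementSpaceSize() elements per access. Assuming the multiplication is elementwise (as the ck multi-index arithmetic suggests), a trivial sketch with made-up values:

#include <array>
#include <cstdio>

int main()
{
    constexpr int nDim = 3;
    const std::array<int, nDim> reordered_idx{1, 0, 2};
    const std::array<int, nDim> vector_tensor_lengths{1, 1, 4}; // e.g. a 4-wide vector along the last dimension

    for(int i = 0; i < nDim; ++i)
        std::printf("%d ", reordered_idx[i] * vector_tensor_lengths[i]); // prints: 1 0 8
    std::printf("\n");
    return 0;
}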
...@@ -201,18 +198,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -201,18 +198,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
StaticallyIndexedArray<bool, nDim> move_on_dim; StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) { static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) { static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &= move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
}); });
}); });
return move_on_dim; return move_on_dim_;
} }
(); ();
...@@ -249,8 +246,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -249,8 +246,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
DstBuffer& dst_buf, DstBuffer& dst_buf,
const DstIteratorHacks& dst_iterator_hacks) const DstIteratorHacks& dst_iterator_hacks)
{ {
static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
DstBuffer::GetAddressSpace() == AddressSpace::Lds, DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!"); "wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>, static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
...@@ -316,9 +313,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -316,9 +313,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) { static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward // judge move forward or move backward
constexpr auto forward_sweep = [&]() { constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep; StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep(I0) = true; forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0]; index_t tmp = ordered_dst_access_idx[I0];
...@@ -327,10 +324,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -327,10 +324,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
}); });
forward_sweep(i) = tmp % 2 == 0; forward_sweep_(i) = tmp % 2 == 0;
}); });
return forward_sweep; return forward_sweep_;
}(); }();
// calculate dst data index // calculate dst data index
...@@ -343,11 +340,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -343,11 +340,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
ordered_dst_access_idx[i]; ordered_dst_access_idx[i];
}); });
auto dst_data_idx = return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_vector_tensor_lengths; dst_vector_tensor_lengths;
return dst_data_idx;
}(); }();
vector_type_maker_t<DstData, dst_vector_desc.GetElementSpaceSize()> dst_vector; vector_type_maker_t<DstData, dst_vector_desc.GetElementSpaceSize()> dst_vector;
...@@ -379,18 +373,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -379,18 +373,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
StaticallyIndexedArray<bool, nDim> move_on_dim; StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) { static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) { static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &= move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
}); });
}); });
return move_on_dim; return move_on_dim_;
} }
(); ();
...@@ -463,9 +457,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -463,9 +457,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
// judge move forward or move backward during the last iteration // judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() { constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep; StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep(I0) = true; forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1; index_t tmp = ordered_src_access_lengths[I0] - 1;
...@@ -474,10 +468,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -474,10 +468,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
}); });
forward_sweep(i) = tmp % 2 == 0; forward_sweep_(i) = tmp % 2 == 0;
}); });
return forward_sweep; return forward_sweep_;
}(); }();
// calculate src data index after last iteration in RunRead(), if it has not been reset by // calculate src data index after last iteration in RunRead(), if it has not been reset by
...@@ -489,19 +483,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -489,19 +483,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
}); });
auto src_data_idx = container_reorder_given_old2new(ordered_idx, src_dim_access_order) * return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_vector_tensor_lengths; src_vector_tensor_lengths;
return src_data_idx;
}(); }();
// //
constexpr auto reset_src_data_step = [&]() { constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step; Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; }); static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
return reset_src_data_step; return reset_src_data_step_;
}(); }();
return reset_src_data_step; return reset_src_data_step;
...@@ -520,9 +512,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -520,9 +512,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
// judge move forward or move backward during the last iteration // judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() { constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep; StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep(I0) = true; forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1; index_t tmp = ordered_dst_access_lengths[I0] - 1;
...@@ -531,10 +523,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -531,10 +523,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
}); });
forward_sweep(i) = tmp % 2 == 0; forward_sweep_(i) = tmp % 2 == 0;
}); });
return forward_sweep; return forward_sweep_;
}(); }();
// calculate dst data index after last iteration in RunWrite(), if it has not been reset by // calculate dst data index after last iteration in RunWrite(), if it has not been reset by
...@@ -546,19 +538,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -546,19 +538,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
}); });
auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_vector_tensor_lengths; dst_vector_tensor_lengths;
return dst_data_idx;
}(); }();
// //
constexpr auto reset_dst_data_step = [&]() { constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step; Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; }); static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step; return reset_dst_data_step_;
}(); }();
return reset_dst_data_step; return reset_dst_data_step;
...@@ -620,7 +610,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ...@@ -620,7 +610,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticBuffer<AddressSpace::Vgpr, SrcData, buffer_size_> buffer_; StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
SrcCoord src_coord_; SrcCoord src_coord_;
DstCoord dst_coord_; DstCoord dst_coord_;
......
#ifndef CK_THREADWISE_GEMM_V3_HPP #ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
#define CK_THREADWISE_GEMM_V3_HPP #define CK_THREADWISE_GEMM_DLOPS_V3_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "math.hpp" #include "math.hpp"
...@@ -22,7 +22,7 @@ template <typename FloatA, ...@@ -22,7 +22,7 @@ template <typename FloatA,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v3 struct ThreadwiseGemmDlops_km_kn_mn_v3
{ {
template <typename ABuffer, template <typename ABuffer,
typename AOriginIdx, typename AOriginIdx,
......