Commit 4a661578 authored by Chao Liu

updated block-cluster in gridwise gemm and thread-cluster in blockwise copy to use cluster descriptor
parent 8b306478
@@ -151,6 +151,10 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

+        // c_block_cluster_desc
+        const auto gemm_block_cluster_desc = make_cluster_descriptor_v2(
+            make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
+
         // hack to control index calculation when iterating over a_k_m_global tensor
         constexpr auto a_k_m_global_iterator_hacks =
             make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
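Note: the descriptor built above replaces hand-rolled div/mod block-id math with the same merge-transform machinery used for tensor descriptors. As a plain-C++ sketch (illustrative names, not the library API), the 2D case with default row-major arrangement boils down to:

#include <array>
#include <cstdint>

// What CalculateBottomIndex computes for a 2D block cluster with lengths
// (GemmM / GemmMPerBlock, GemmN / GemmNPerBlock): split a linear block id
// into (m, n) tile coordinates.
std::array<int64_t, 2> block_id_to_tile_idx(int64_t block_id,
                                            int64_t n_tiles) // GemmN / GemmNPerBlock
{
    return {block_id / n_tiles,  // m tile coordinate (slower-varying)
            block_id % n_tiles}; // n tile coordinate (faster-varying)
}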
@@ -190,6 +194,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
             decltype(wei_gemmk_gemmm_global_desc),
             decltype(in_gemmk_gemmn_global_desc),
             decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+            decltype(gemm_block_cluster_desc),
             GemmMPerBlock,
             GemmNPerBlock,
             GemmKPerBlock,
@@ -255,6 +260,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
             decltype(
                 out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
             FloatC*,
+            decltype(gemm_block_cluster_desc),
             integral_constant<bool, true>,
             integral_constant<bool, true>>;
@@ -269,6 +275,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
                 p_in_global,
                 out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
                 p_out_global,
+                gemm_block_cluster_desc,
                 integral_constant<bool, true>{},
                 integral_constant<bool, true>{});
         }
@@ -283,6 +290,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
             decltype(
                 out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
             FloatC*,
+            decltype(gemm_block_cluster_desc),
             integral_constant<bool, true>,
             integral_constant<bool, false>>;
@@ -297,6 +305,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
                 p_in_global,
                 out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
                 p_out_global,
+                gemm_block_cluster_desc,
                 integral_constant<bool, true>{},
                 integral_constant<bool, false>{});
         }
@@ -311,6 +320,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
             decltype(
                 out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
             FloatC*,
+            decltype(gemm_block_cluster_desc),
             integral_constant<bool, false>,
             integral_constant<bool, true>>;
@@ -325,6 +335,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
                 p_in_global,
                 out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
                 p_out_global,
+                gemm_block_cluster_desc,
                 integral_constant<bool, false>{},
                 integral_constant<bool, true>{});
         }
@@ -339,6 +350,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
             decltype(
                 out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
             FloatC*,
+            decltype(gemm_block_cluster_desc),
             integral_constant<bool, false>,
             integral_constant<bool, false>>;
@@ -353,6 +365,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
                 p_in_global,
                 out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
                 p_out_global,
+                gemm_block_cluster_desc,
                 integral_constant<bool, false>{},
                 integral_constant<bool, false>{});
         }
@@ -525,6 +538,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
     }
 };

+#if 0
 // GemmM = K
 // GemmN = N * Ho * Wo
 // GemmK = C * Y * X
@@ -1530,6 +1544,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
 #endif
     }
 };
+#endif

 } // namespace ck
 #endif
@@ -48,16 +48,17 @@ __host__ __device__ constexpr auto make_cluster_descriptor(
 template <typename Lengths,
           typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
 __host__ __device__ constexpr auto make_cluster_descriptor_v2(
-    Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
+    const Lengths& lengths,
+    ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
 {
-    constexpr auto reordered_lengths = Lengths::ReorderGivenNew2Old(ArrangeOrder{});
-    constexpr index_t ndim_low = reordered_lengths.Size();
-    constexpr auto low_lengths = generate_tuple(
-        [&](auto idim_low) { return Number<reordered_lengths[idim_low]>{}; }, Number<ndim_low>{});
-    constexpr auto transform = make_merge_transform(low_lengths);
+    constexpr index_t ndim_low = Lengths::Size();
+    const auto reordered_lengths = container_reorder_given_new2old(lengths, order);
+    const auto low_lengths = generate_tuple(
+        [&](auto idim_low) { return reordered_lengths[idim_low]; }, Number<ndim_low>{});
+    const auto transform = make_merge_transform(low_lengths);
     constexpr auto low_dim_old_top_ids = ArrangeOrder{};
...
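Note: make_cluster_descriptor_v2 now takes the lengths by reference (so runtime values work) and reorders them with container_reorder_given_new2old before merging. The merge transform's bottom-index calculation is an ordinary mixed-radix decomposition; a minimal standalone sketch (hypothetical helper, not the library's):

#include <cstddef>
#include <cstdint>
#include <vector>

// lengths are in arranged (new) order, slowest-varying dimension first;
// returns one coordinate per dimension.
std::vector<int64_t> calculate_bottom_index(int64_t linear_id,
                                            const std::vector<int64_t>& lengths)
{
    std::vector<int64_t> idx(lengths.size());
    for(std::size_t d = lengths.size(); d-- > 0;) // peel off the fastest dimension first
    {
        idx[d] = linear_id % lengths[d];
        linear_id /= lengths[d];
    }
    return idx;
}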
@@ -12,27 +12,6 @@ struct DynamicTensorCoordinate;
 template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
 struct DynamicTensorCoordinateIterator;

-#if 0
-template <typename LowerDimensionIdss, typename UpperDimensionIdss>
-__host__ __device__ constexpr index_t GetNumOfHiddenDimension(LowerDimensionIdss,
-                                                              UpperDimensionIdss)
-{
-    constexpr auto all_low_dim_ids =
-        unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
-    constexpr auto all_up_dim_ids =
-        unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
-    constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
-
-    using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
-                                                                  math::less<index_t>,
-                                                                  math::equal<index_t>>::type;
-
-    return unique_sort_all_dim_ids::Size();
-}
-#endif
-
 // Transforms: Tuple<transforms...>
 // LowerDimensionIdss : Tuple<Sequence<...>, ...>
 // UpperDimensionIdss : Tuple<Sequence<...>, ...>
...
@@ -67,26 +67,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            const auto thread_cluster_id = thread_cluster_desc_.CalculateBottomIndex(
+            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                 make_multi_index(get_thread_local_1d_id()));

-            const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
+            const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};

             threadwise_transfer_.SetSrcSliceOrigin(src_desc,
-                                                   src_block_slice_origin + thread_data_id_begin);
+                                                   src_block_slice_origin + thread_data_idx_begin);

             threadwise_transfer_.SetDstSliceOrigin(dst_desc,
-                                                   dst_block_slice_origin + thread_data_id_begin);
+                                                   dst_block_slice_origin + thread_data_idx_begin);
         }
     }

-    __device__ static constexpr auto CalculateThreadDataBegin()
-    {
-        const auto thread_cluster_id =
-            thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
-
-        return thread_cluster_id * ThreadSliceLengths{};
-    }
-
     template <typename SrcIteratorHacks>
     __device__ void RunRead(const SrcDesc& src_desc,
                             const SrcData* p_src,
@@ -141,6 +133,7 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         }
     }

+    private:
     static constexpr auto thread_cluster_desc_ =
         make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
...
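Note: with CalculateThreadDataBegin removed, the constructor derives each thread's slice origin directly from the cluster descriptor: the thread's cluster coordinate times the per-thread slice lengths, element-wise. A self-contained sketch of that arithmetic (illustrative names, not the library types):

#include <array>
#include <cstddef>
#include <cstdint>

// Per-dimension origin of one thread's data slice inside the block tile.
template <std::size_t N>
std::array<int64_t, N> thread_data_idx_begin(const std::array<int64_t, N>& thread_cluster_idx,
                                             const std::array<int64_t, N>& thread_slice_lengths)
{
    std::array<int64_t, N> origin{};
    for(std::size_t d = 0; d < N; ++d)
        origin[d] = thread_cluster_idx[d] * thread_slice_lengths[d]; // offset in elements
    return origin;
}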
@@ -47,44 +47,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         index_t col;
     };

-    private:
-    static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<0>{})));
-
-    static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<1>{})));
-
-    using AThreadCopy =
-        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
-                                                FloatA,
-                                                BlockMatrixA,
-                                                decltype(a_thread_mtx_desc_),
-                                                Sequence<KPerThreadLoop, MPerThreadSubC>,
-                                                Sequence<0, 1>,
-                                                1,
-                                                ThreadGemmADataPerRead_M,
-                                                AddressSpace::Generic,
-                                                AddressSpace::Vgpr,
-                                                1>;
-
-    using BThreadCopy =
-        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
-                                                FloatB,
-                                                BlockMatrixB,
-                                                decltype(b_thread_mtx_desc_),
-                                                Sequence<KPerThreadLoop, NPerThreadSubC>,
-                                                Sequence<0, 1>,
-                                                1,
-                                                ThreadGemmBDataPerRead_N,
-                                                AddressSpace::Generic,
-                                                AddressSpace::Vgpr,
-                                                1>;
-
-    MatrixIndex c_thread_begin_mtx_idx_;
-
-    AThreadCopy a_thread_copy_;
-    BThreadCopy b_thread_copy_;
-
     public:
     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1()
         : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
@@ -136,21 +98,20 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
     __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
     {
-        constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
-
-        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
-        index_t level1_m_id = level1_id / NLevel1ThreadCluster;
-        index_t level1_n_id = level1_id % NLevel1ThreadCluster;
-
-        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
-        index_t level0_m_id = level0_id / NLevel0ThreadCluster;
-        index_t level0_n_id = level0_id % NLevel0ThreadCluster;
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        const auto thread_cluster_idx =
+            c_thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));

         constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
         constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;

-        return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
-                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
+        return MatrixIndex{
+            thread_cluster_idx[I0] * MPerLevel0Cluster + thread_cluster_idx[I2] * MPerThreadSubC,
+            thread_cluster_idx[I1] * NPerLevel0Cluster + thread_cluster_idx[I3] * NPerThreadSubC};
     }

     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
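Note: GetBeginOfThreadMatrixC now delegates the two-level div/mod chain to the 4D thread cluster descriptor, whose dimensions I0..I3 correspond to (MLevel1, NLevel1, MLevel0, NLevel0) in arrange order <0, 1, 2, 3>. A plain-C++ sketch of the equivalent computation (illustrative names, mirroring the code above):

#include <cstdint>

struct MatrixIndex { int64_t row; int64_t col; };

MatrixIndex begin_of_thread_matrix_c(int64_t thread_id,
                                     int64_t m_lvl1, int64_t n_lvl1,
                                     int64_t m_lvl0, int64_t n_lvl0,
                                     int64_t m_per_thread, int64_t n_per_thread)
{
    // mixed-radix decomposition over [MLevel1, NLevel1, MLevel0, NLevel0],
    // fastest-varying dimension last
    const int64_t i3 = thread_id % n_lvl0; thread_id /= n_lvl0;
    const int64_t i2 = thread_id % m_lvl0; thread_id /= m_lvl0;
    const int64_t i1 = thread_id % n_lvl1; thread_id /= n_lvl1;
    const int64_t i0 = thread_id % m_lvl1;

    const int64_t m_per_lvl0_cluster = m_per_thread * m_lvl0;
    const int64_t n_per_lvl0_cluster = n_per_thread * n_lvl0;

    return {i0 * m_per_lvl0_cluster + i2 * m_per_thread,
            i1 * n_per_lvl0_cluster + i3 * n_per_thread};
}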
@@ -371,6 +332,51 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         Run_naive(a_block_buf, b_block_buf, c_thread_buf);
 #endif
     }

+    private:
+    static constexpr auto c_thread_cluster_desc_ =
+        make_cluster_descriptor_v2(Sequence<MLevel1ThreadCluster,
+                                            NLevel1ThreadCluster,
+                                            MLevel0ThreadCluster,
+                                            NLevel0ThreadCluster>{},
+                                   Sequence<0, 1, 2, 3>{});
+
+    static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<0>{})));
+
+    static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<1>{})));
+
+    using AThreadCopy =
+        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
+                                                FloatA,
+                                                BlockMatrixA,
+                                                decltype(a_thread_mtx_desc_),
+                                                Sequence<KPerThreadLoop, MPerThreadSubC>,
+                                                Sequence<0, 1>,
+                                                1,
+                                                ThreadGemmADataPerRead_M,
+                                                AddressSpace::Generic,
+                                                AddressSpace::Vgpr,
+                                                1>;
+
+    using BThreadCopy =
+        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
+                                                FloatB,
+                                                BlockMatrixB,
+                                                decltype(b_thread_mtx_desc_),
+                                                Sequence<KPerThreadLoop, NPerThreadSubC>,
+                                                Sequence<0, 1>,
+                                                1,
+                                                ThreadGemmBDataPerRead_N,
+                                                AddressSpace::Generic,
+                                                AddressSpace::Vgpr,
+                                                1>;
+
+    MatrixIndex c_thread_begin_mtx_idx_;
+
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
 };

 } // namespace ck
 #endif
@@ -23,6 +23,7 @@ template <typename GridwiseGemm,
           typename FloatB,
           typename CGlobalDesc,
           typename FloatC,
+          typename CBlockClusterDesc,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
 __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
@@ -30,9 +31,10 @@ __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_gl
                                              const void __CONSTANT__* p_b_k_n_global_desc,
                                              const FloatB* __restrict__ p_b_global,
                                              const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
-                                             FloatC* __restrict__ p_c_global)
+                                             FloatC* __restrict__ p_c_global,
+                                             const void __CONSTANT__* p_c_block_cluster_desc)
 {
-    // first cast void __CONSTANT__* to void*
+    // first cast void __CONSTANT__ void* to void*
     // second cast void* to Desc*
     // the copy constructor of tensor descriptor doesn't take address_space(4)
     const auto a_k_m_global_desc =
@@ -42,12 +44,16 @@ __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_gl
     const auto c_m0_m1_n0_n1_global_desc =
         *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);

+    const auto c_block_cluster_desc =
+        *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
+
     GridwiseGemm{}.Run(a_k_m_global_desc,
                        p_a_global,
                        b_k_n_global_desc,
                        p_b_global,
                        c_m0_m1_n0_n1_global_desc,
                        p_c_global,
+                       c_block_cluster_desc,
                        integral_constant<bool, HasMainKBlockLoop>{},
                        integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
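Note: the new cluster descriptor is forwarded the same way as the tensor descriptors: the kernel receives a pointer into constant address space and copies the object out through a generic void*. A minimal sketch of that two-step cast, assuming __CONSTANT__ expands to an AMDGPU address_space(4) qualifier as the comment above suggests (Desc is a stand-in type):

struct Desc { int lengths[4]; }; // stand-in for a tensor/cluster descriptor

__device__ Desc load_desc(const void __CONSTANT__* p)
{
    // Step 1: the C-style cast strips the address-space qualifier.
    // Step 2: reinterpret as the concrete type and copy by value, since the
    // descriptor's copy constructor doesn't accept address_space(4) sources.
    return *reinterpret_cast<const Desc*>((const void*)p);
}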
@@ -61,6 +67,7 @@ template <index_t BlockSize,
           typename AGlobalDesc,
           typename BGlobalDesc,
           typename CGlobalDesc,
+          typename CBlockClusterDesc,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t KPerBlock,
@@ -131,6 +138,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                         const FloatAB* __restrict__ p_b_global,
                         const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                         FloatC* __restrict__ p_c_global,
+                        const CBlockClusterDesc& c_block_cluster_desc,
                         FloatAB* __restrict__ p_shared_block,
                         integral_constant<bool, HasMainKBlockLoop>,
                         integral_constant<bool, HasDoubleTailKBlockLoop>) const
@@ -143,25 +151,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         const auto N = b_k_n_global_desc.GetLength(I1);

         // divide block work by [M, N]
-#if 0
-        const auto m_block_work_num = M / Number<MPerBlock>{};
-        const auto n_block_work_num = N / Number<NPerBlock>{};
-
-        const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
-        const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
-#else
-        // Hack: this force result into SGPR
-        const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock);
-        const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock);
-
-        const index_t m_block_work_id =
-            __builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num);
-        const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
-#endif
-
-        const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
-        const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
+        const auto block_work_idx =
+            c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+
+        // HACK: this force m/n_block_data_idx_on_global into SGPR
+        const index_t m_block_data_idx_on_global =
+            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+
+        const index_t n_block_data_idx_on_global =
+            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

         // lds max alignment
         constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
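Note: the readfirstlane hack survives the rewrite because block_work_idx comes out of CalculateBottomIndex in vector registers even though get_block_1d_id() is uniform across the wavefront; __builtin_amdgcn_readfirstlane (a Clang AMDGPU builtin) asserts wave-uniformity so the offsets can live in SGPRs. A minimal HIP-style illustration of the pattern, not code from this commit:

#include <hip/hip_runtime.h>

__global__ void tile_offset_example(int* out, int n_tiles, int n_per_block)
{
    // Per-lane arithmetic on a uniform value would normally occupy VGPRs.
    const int n_tile_idx = blockIdx.x % n_tiles;

    // Reading lane 0 tells the compiler the result is wave-uniform,
    // steering it into a scalar register (SGPR).
    const int n_block_data_idx =
        __builtin_amdgcn_readfirstlane(n_tile_idx * n_per_block);

    if(threadIdx.x == 0)
        out[blockIdx.x] = n_block_data_idx;
}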
@@ -204,7 +202,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 AThreadTransferSrcResetCoordinateAfterRun,
                 true>(
                 a_k_m_global_desc,
-                make_multi_index(0, m_block_data_on_global),
+                make_multi_index(0, m_block_data_idx_on_global),
                 a_k_m_block_desc,
                 make_multi_index(0, 0));
@@ -233,7 +231,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 BThreadTransferSrcResetCoordinateAfterRun,
                 true>(
                 b_k_n_global_desc,
-                make_multi_index(0, n_block_data_on_global),
+                make_multi_index(0, n_block_data_idx_on_global),
                 b_k_n_block_desc,
                 make_multi_index(0, 0));
@@ -441,10 +439,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

             const index_t m_thread_data_on_global =
-                m_block_data_on_global + c_thread_mtx_on_block.row;
+                m_block_data_idx_on_global + c_thread_mtx_on_block.row;

             const index_t n_thread_data_on_global =
-                n_block_data_on_global + c_thread_mtx_on_block.col;
+                n_block_data_idx_on_global + c_thread_mtx_on_block.col;

             // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
             constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
@@ -486,6 +484,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                         const FloatAB* __restrict__ p_b_global,
                         const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                         FloatC* __restrict__ p_c_global,
+                        const CBlockClusterDesc& c_block_cluster_desc,
                         integral_constant<bool, HasMainKBlockLoop>,
                         integral_constant<bool, HasDoubleTailKBlockLoop>) const
     {
@@ -499,6 +498,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                            p_b_global,
                            c_m0_m1_n0_n1_global_desc,
                            p_c_global,
+                           c_block_cluster_desc,
                            p_shared_block,
                            integral_constant<bool, HasMainKBlockLoop>{},
                            integral_constant<bool, HasDoubleTailKBlockLoop>{});
...
@@ -15,8 +15,8 @@
 #include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
 #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
-#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
-#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
+//#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
+//#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"

 int main(int argc, char* argv[])
 {
...