Commit 49b926b6 authored by Chao Liu

overhauling fwd-v4r4

parent b8382727
@@ -26,21 +26,21 @@ template <index_t BlockSize,
           index_t M1N1ThreadClusterN10,
           index_t M1N1ThreadClusterM11,
           index_t M1N1ThreadClusterN11,
-          typename ABlockTransferThreadSliceLengths_K_M,
-          typename ABlockTransferThreadClusterLengths_K_M,
+          typename ABlockTransferThreadSliceLengths_K_M0_M1,
+          typename ABlockTransferThreadClusterLengths_K_M0_M1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
           index_t ABlockTransferSrcVectorDim,
           index_t ABlockTransferSrcScalarPerVector,
-          index_t ABlockTransferDstScalarPerVector_M,
+          index_t ABlockTransferDstScalarPerVector_M1,
           bool AThreadTransferSrcResetCoordinateAfterRun,
-          typename BBlockTransferThreadSliceLengths_K_N,
-          typename BBlockTransferThreadClusterLengths_K_N,
+          typename BBlockTransferThreadSliceLengths_K_N0_N1,
+          typename BBlockTransferThreadClusterLengths_K_N0_N1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
           index_t BBlockTransferSrcVectorDim,
           index_t BBlockTransferSrcScalarPerVector,
-          index_t BBlockTransferDstScalarPerVector_N,
+          index_t BBlockTransferDstScalarPerVector_N1,
           bool BThreadTransferSrcResetCoordinateAfterRun,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
@@ -69,48 +69,48 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
     // GEMM
     using GridwiseGemm =
-        GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2<BlockSize,
+        GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize,
                                           FloatAB,
                                           FloatAcc,
                                           FloatC,
                                           CGlobalMemoryDataOperation,
                                           AKMGridDesc,
                                           BKNGridDesc,
                                           CMNGridDesc,
                                           MPerBlock,
                                           NPerBlock,
                                           KPerBlock,
                                           M1PerThread,
                                           N1PerThread,
                                           KPerThread,
                                           M1N1ThreadClusterM10,
                                           M1N1ThreadClusterN10,
                                           M1N1ThreadClusterM11,
                                           M1N1ThreadClusterN11,
-                                          ABlockTransferThreadSliceLengths_K_M,
-                                          ABlockTransferThreadClusterLengths_K_M,
+                                          ABlockTransferThreadSliceLengths_K_M0_M1,
+                                          ABlockTransferThreadClusterLengths_K_M0_M1,
                                           ABlockTransferThreadClusterArrangeOrder,
                                           ABlockTransferSrcAccessOrder,
                                           ABlockTransferSrcVectorDim,
                                           ABlockTransferSrcScalarPerVector,
-                                          ABlockTransferDstScalarPerVector_M,
+                                          ABlockTransferDstScalarPerVector_M1,
                                           AThreadTransferSrcResetCoordinateAfterRun,
-                                          BBlockTransferThreadSliceLengths_K_N,
-                                          BBlockTransferThreadClusterLengths_K_N,
+                                          BBlockTransferThreadSliceLengths_K_N0_N1,
+                                          BBlockTransferThreadClusterLengths_K_N0_N1,
                                           BBlockTransferThreadClusterArrangeOrder,
                                           BBlockTransferSrcAccessOrder,
                                           BBlockTransferSrcVectorDim,
                                           BBlockTransferSrcScalarPerVector,
-                                          BBlockTransferDstScalarPerVector_N,
+                                          BBlockTransferDstScalarPerVector_N1,
                                           BThreadTransferSrcResetCoordinateAfterRun,
                                           CThreadTransferSrcDstAccessOrder,
                                           CThreadTransferSrcDstVectorDim,
                                           CThreadTransferDstScalarPerVector,
                                           AGridIteratorHacks,
                                           BGridIteratorHacks,
                                           CGridIteratorHacks,
                                           AGridMoveSliceWindowIteratorHacks,
                                           BGridMoveSliceWindowIteratorHacks>;

     const auto M = a_k_m_grid_desc.GetLength(I1);
     const auto N = b_k_n_grid_desc.GetLength(I1);
@@ -118,8 +118,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
     if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
     {
-        throw std::runtime_error(
-            "wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2 has invalid setting");
+        throw std::runtime_error("wrong! GridwiseDynamicGemm_km_kn_mn_v1r2 has invalid setting");
     }

     const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
@@ -154,8 +153,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
             kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                      FloatAB,
                                      FloatC,
-                                     remove_reference_t<AKMGridDesc>,
-                                     remove_reference_t<BKNGridDesc>,
                                      remove_reference_t<AKM0M1GridDesc>,
                                      remove_reference_t<BKN0N1GridDesc>,
                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
@@ -172,8 +169,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                 p_a_grid,
                 p_b_grid,
                 p_c_grid,
-                a_k_m_grid_desc,
-                b_k_n_grid_desc,
                 a_k_m0_m1_grid_desc,
                 b_k_n0_n1_grid_desc,
                 c_m0_m10_m11_n0_n10_n11_grid_desc,
@@ -185,8 +180,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
             kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                      FloatAB,
                                      FloatC,
-                                     remove_reference_t<AKMGridDesc>,
-                                     remove_reference_t<BKNGridDesc>,
                                      remove_reference_t<AKM0M1GridDesc>,
                                      remove_reference_t<BKN0N1GridDesc>,
                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
@@ -203,8 +196,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                 p_a_grid,
                 p_b_grid,
                 p_c_grid,
-                a_k_m_grid_desc,
-                b_k_n_grid_desc,
                 a_k_m0_m1_grid_desc,
                 b_k_n0_n1_grid_desc,
                 c_m0_m10_m11_n0_n10_n11_grid_desc,
@@ -216,8 +207,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
             kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                      FloatAB,
                                      FloatC,
-                                     remove_reference_t<AKMGridDesc>,
-                                     remove_reference_t<BKNGridDesc>,
                                      remove_reference_t<AKM0M1GridDesc>,
                                      remove_reference_t<BKN0N1GridDesc>,
                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
@@ -234,8 +223,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                 p_a_grid,
                 p_b_grid,
                 p_c_grid,
-                a_k_m_grid_desc,
-                b_k_n_grid_desc,
                 a_k_m0_m1_grid_desc,
                 b_k_n0_n1_grid_desc,
                 c_m0_m10_m11_n0_n10_n11_grid_desc,
@@ -247,8 +234,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
             kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                      FloatAB,
                                      FloatC,
-                                     remove_reference_t<AKMGridDesc>,
-                                     remove_reference_t<BKNGridDesc>,
                                      remove_reference_t<AKM0M1GridDesc>,
                                      remove_reference_t<BKN0N1GridDesc>,
                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
@@ -265,8 +250,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                 p_a_grid,
                 p_b_grid,
                 p_c_grid,
-                a_k_m_grid_desc,
-                b_k_n_grid_desc,
                 a_k_m0_m1_grid_desc,
                 b_k_n0_n1_grid_desc,
                 c_m0_m10_m11_n0_n10_n11_grid_desc,
......
@@ -15,8 +15,6 @@ namespace ck {
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename AKMGridDesc,
-          typename BKNGridDesc,
           typename AKM0M1GridDesc,
           typename BKN0N1GridDesc,
           typename CM0M10M11N0N10N11GridDesc,
@@ -31,8 +29,6 @@ __global__ void
         const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
-        const AKMGridDesc a_k_m_grid_desc,
-        const BKNGridDesc b_k_n_grid_desc,
         const AKM0M1GridDesc a_k_m0_m1_grid_desc,
         const BKN0N1GridDesc b_k_n0_n1_grid_desc,
         const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
@@ -47,8 +43,6 @@ __global__ void
                       p_b_grid,
                       p_c_grid,
                       p_shared_block,
-                      a_k_m_grid_desc,
-                      b_k_n_grid_desc,
                       a_k_m0_m1_grid_desc,
                       b_k_n0_n1_grid_desc,
                       c_m0_m10_m11_n0_n10_n11_grid_desc,
@@ -99,7 +93,7 @@ template <index_t BlockSize,
           typename CGridIteratorHacks,
           typename AGridMoveSliceWindowIteratorHacks,
           typename BGridMoveSliceWindowIteratorHacks>
-struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
+struct GridwiseDynamicGemm_km_kn_mn_v1r2
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -124,13 +118,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);

         // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_space_size =
+        constexpr auto a_block_aligned_space_size =
             math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
-        constexpr auto b_block_space_size =
+        constexpr auto b_block_aligned_space_size =
             math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);

-        return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
+        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
     }

     __host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc,
@@ -178,13 +172,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         const auto M1 = Number<MPerBlock>{};
         const auto M0 = M / M1;

-        const auto a_k_m0_m1_block_clusterized_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor(
             a_k_m_grid_desc,
             make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

-        return a_k_m0_m1_block_clusterized_grid_desc;
+        return a_k_m0_m1_grid_desc;
     }

     __host__ __device__ static constexpr auto
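For orientation, the unmerge above splits the grid's M dimension into (M0, M1) with M1 fixed to MPerBlock, so each block only ever addresses a single M0 slot; the B-side descriptor below does the same for N. The underlying index relation is just m = m0 * MPerBlock + m1. A minimal standalone sketch of that relation (plain C++; the function name is hypothetical, not from this repository):

    // m is recovered from (m0, m1), with M1 == MPerBlock:
    //   m = m0 * MPerBlock + m1,  where 0 <= m1 < MPerBlock
    constexpr int unmerge_m(int m0, int m1, int m_per_block)
    {
        return m0 * m_per_block + m1;
    }

    static_assert(unmerge_m(3, 5, 128) == 3 * 128 + 5, "unmerge index math");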
@@ -196,13 +190,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         const auto N1 = Number<NPerBlock>{};
         const auto N0 = N / N1;

-        const auto b_k_n0_n1_block_clusterized_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor(
             b_k_n_grid_desc,
             make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

-        return b_k_n0_n1_block_clusterized_grid_desc;
+        return b_k_n0_n1_grid_desc;
     }

     __host__ __device__ static constexpr auto
@@ -246,7 +240,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         const auto N0 = N / N1;

         const auto c_blockid_to_m0_n0_block_cluster_adaptor =
-            make_cluster_descriptor_v2(make_tuple(M0, N0));
+            make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
+                                             make_tuple(Sequence<0, 1>{}),
+                                             make_tuple(Sequence<0>{}));

         return c_blockid_to_m0_n0_block_cluster_adaptor;
     }
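The switch from make_cluster_descriptor_v2 to an explicit single-stage merge adaptor keeps the block-id decomposition itself unchanged: the flat block id is a merge of (M0, N0), so CalculateBottomIndex recovers the tile coordinate by division and modulo. A minimal sketch of that math, assuming the conventional row-major merge order (plain C++; names hypothetical):

    struct TileCoord
    {
        int m0;
        int n0;
    };

    // block_id == m0 * N0 + n0, so it is inverted with / and %
    constexpr TileCoord decompose_block_id(int block_id, int n0_count)
    {
        return {block_id / n0_count, block_id % n0_count};
    }

    static_assert(decompose_block_id(7, 4).m0 == 1, "7 / 4 == 1");
    static_assert(decompose_block_id(7, 4).n0 == 3, "7 % 4 == 3");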
@@ -263,8 +259,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                                            const FloatAB* __restrict__ p_b_grid,
                                            FloatC* __restrict__ p_c_grid,
                                            FloatAB* __restrict__ p_shared_block,
-                                           const AKMGridDesc& a_k_m_grid_desc,
-                                           const BKNGridDesc& b_k_n_grid_desc,
                                            const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
                                            const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
                                            const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
@@ -273,30 +267,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                              integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_a_grid, a_k_m_grid_desc.GetElementSpaceSize());
+            p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
         const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_b_grid, b_k_n_grid_desc.GetElementSpaceSize());
+            p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
             p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());

-        const auto K = a_k_m_grid_desc.GetLength(I0);
-        const auto M = a_k_m_grid_desc.GetLength(I1);
-        const auto N = b_k_n_grid_desc.GetLength(I1);
+        const auto K = a_k_m0_m1_grid_desc.GetLength(I0);

         // divide block work by [M, N]
-        const auto block_work_idx = c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
-            make_multi_index(get_block_1d_id()));
+        const auto c_m0_n0_block_cluster_idx =
+            c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
+                make_multi_index(get_block_1d_id()));

         // HACK: this force index data into SGPR
-        const index_t m_block_work_idx = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
-        const index_t n_block_work_idx = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
-
-        const index_t m_block_data_idx_on_global =
-            __builtin_amdgcn_readfirstlane(m_block_work_idx * MPerBlock);
-        const index_t n_block_data_idx_on_global =
-            __builtin_amdgcn_readfirstlane(n_block_work_idx * NPerBlock);
+        const index_t m0_idx = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
+        const index_t n0_idx = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

         // lds max alignment
         constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
@@ -314,59 +300,67 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);

+        // A matrix in LDS memory, dst of blockwise copy
+        // be careful of LDS alignment
+        constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlock>{}), max_lds_align);
+
+        // B matrix in LDS memory, dst of blockwise copy
+        // be careful of LDS alignment
+        constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlock>{}), max_lds_align);
+
         // A matrix blockwise copy
         auto a_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                    InMemoryDataOperation::Set,
-                                                   Sequence<KPerBlock, MPerBlock>,
+                                                   Sequence<KPerBlock, 1, MPerBlock>,
                                                    ABlockTransferThreadSliceLengths_K_M,
                                                    ABlockTransferThreadClusterLengths_K_M,
                                                    ABlockTransferThreadClusterArrangeOrder,
                                                    FloatAB,
                                                    FloatAB,
-                                                   decltype(a_k_m_grid_desc),
-                                                   decltype(a_k_m_block_desc),
+                                                   decltype(a_k_m0_m1_grid_desc),
+                                                   decltype(a_k_m0_m1_block_desc),
                                                    ABlockTransferSrcAccessOrder,
-                                                   Sequence<0, 1>,
+                                                   Sequence<0, 1, 2>,
                                                    ABlockTransferSrcVectorDim,
-                                                   1,
+                                                   2,
                                                    ABlockTransferSrcScalarPerVector,
                                                    ABlockTransferDstScalarPerVector_M,
                                                    1,
                                                    1,
                                                    AThreadTransferSrcResetCoordinateAfterRun,
-                                                   true>(
-                a_k_m_grid_desc,
-                make_multi_index(0, m_block_data_idx_on_global),
-                a_k_m_block_desc,
-                make_multi_index(0, 0));
+                                                   true>(a_k_m0_m1_grid_desc,
+                                                         make_multi_index(0, m0_idx, 0),
+                                                         a_k_m0_m1_block_desc,
+                                                         make_multi_index(0, 0, 0));

         // B matrix blockwise copy
         auto b_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                    InMemoryDataOperation::Set,
-                                                   Sequence<KPerBlock, NPerBlock>,
+                                                   Sequence<KPerBlock, 1, NPerBlock>,
                                                    BBlockTransferThreadSliceLengths_K_N,
                                                    BBlockTransferThreadClusterLengths_K_N,
                                                    BBlockTransferThreadClusterArrangeOrder,
                                                    FloatAB,
                                                    FloatAB,
-                                                   decltype(b_k_n_grid_desc),
-                                                   decltype(b_k_n_block_desc),
+                                                   decltype(b_k_n0_n1_grid_desc),
+                                                   decltype(b_k_n0_n1_block_desc),
                                                    BBlockTransferSrcAccessOrder,
-                                                   Sequence<0, 1>,
+                                                   Sequence<0, 1, 2>,
                                                    BBlockTransferSrcVectorDim,
-                                                   1,
+                                                   2,
                                                    BBlockTransferSrcScalarPerVector,
                                                    BBlockTransferDstScalarPerVector_N,
                                                    1,
                                                    1,
                                                    BThreadTransferSrcResetCoordinateAfterRun,
-                                                   true>(
-                b_k_n_grid_desc,
-                make_multi_index(0, n_block_data_idx_on_global),
-                b_k_n_block_desc,
-                make_multi_index(0, 0));
+                                                   true>(b_k_n0_n1_grid_desc,
+                                                         make_multi_index(0, n0_idx, 0),
+                                                         b_k_n0_n1_block_desc,
+                                                         make_multi_index(0, 0, 0));

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
@@ -398,14 +392,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             sequence_to_tuple_of_number(c_m10_n10_m11_n11_thread_tensor_lengths));

         // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_space_size =
-            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
+        constexpr auto a_block_aligned_space_size =
+            math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align);
-        constexpr auto b_block_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
+        constexpr auto b_block_aligned_space_size =
+            math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align);

         FloatAB* p_a_block_double = p_shared_block;
-        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
+        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

         // register allocation for output
         auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
@@ -419,37 +413,41 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                           c_thread_buf,
                           FloatAcc{0});

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);

         // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k_m_global_iterator_hacks = AGridIteratorHacks{};
-        constexpr auto b_k_n_global_iterator_hacks = BGridIteratorHacks{};
+        constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
+        constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};

         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
-        constexpr auto a_k_m_global_move_slice_window_iterator_hack =
+        constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
             AGridMoveSliceWindowIteratorHacks{};
-        constexpr auto b_k_n_global_move_slice_window_iterator_hack =
+        constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
             BGridMoveSliceWindowIteratorHacks{};

         auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
+            p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
         auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
+            p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());

-        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
-        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());
+        auto a_block_odd_buf =
+            make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double + a_block_aligned_space_size,
+                                                   a_k_m0_m1_block_desc.GetElementSpaceSize());
+        auto b_block_odd_buf =
+            make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double + b_block_aligned_space_size,
+                                                   b_k_n0_n1_block_desc.GetElementSpaceSize());

         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(
+                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
+            b_blockwise_copy.RunRead(
+                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
+            a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
         }

         if constexpr(HasMainKBlockLoop)
@@ -461,20 +459,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             do
             {
                 // even iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_grid_desc,
-                                                    a_block_slice_copy_step,
-                                                    a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_grid_desc,
-                                                    b_block_slice_copy_step,
-                                                    b_k_n_global_move_slice_window_iterator_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(
+                    a_k_m0_m1_grid_desc,
+                    a_block_slice_copy_step,
+                    a_k_m0_m1_global_move_slice_window_iterator_hack);
+                b_blockwise_copy.MoveSrcSliceWindow(
+                    b_k_n0_n1_grid_desc,
+                    b_block_slice_copy_step,
+                    b_k_n0_n1_global_move_slice_window_iterator_hack);

                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
+                    a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
+                    b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(c_m10_n10_m11_n11_thread_desc,
@@ -483,32 +483,34 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                                    c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
+                a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
+                b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);

                 // odd iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_grid_desc,
-                                                    a_block_slice_copy_step,
-                                                    a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_grid_desc,
-                                                    b_block_slice_copy_step,
-                                                    b_k_n_global_move_slice_window_iterator_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(
+                    a_k_m0_m1_grid_desc,
+                    a_block_slice_copy_step,
+                    a_k_m0_m1_global_move_slice_window_iterator_hack);
+                b_blockwise_copy.MoveSrcSliceWindow(
+                    b_k_n0_n1_grid_desc,
+                    b_block_slice_copy_step,
+                    b_k_n0_n1_global_move_slice_window_iterator_hack);

                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
+                    a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
+                    b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(
                     c_m10_n10_m11_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
+                a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
+                b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);

                 k_block_data_begin += 2 * KPerBlock;
             } while(k_block_data_begin < K - 2 * KPerBlock);
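Stripped of the tensor-descriptor machinery, the even/odd loop above is the standard two-stage LDS pipeline: while the GEMM consumes one buffer, the next K-slice is prefetched and written to the other. A host-side sketch of the same schedule (illustrative only; the int vectors stand in for the a/b LDS buffers):

    #include <algorithm>
    #include <cassert>
    #include <numeric>
    #include <vector>

    // Sum `data` in slices of k_per_block, ping-ponging between two staging
    // buffers exactly like the even/odd iterations above (tail folded in).
    int double_buffered_sum(const std::vector<int>& data, int k_per_block)
    {
        const int K = static_cast<int>(data.size());
        assert(K > 0 && K % (2 * k_per_block) == 0);

        std::vector<int> even(data.begin(), data.begin() + k_per_block); // preload
        std::vector<int> odd(k_per_block);

        int acc = 0;
        int k   = 0;
        do
        {
            // even iteration: prefetch next slice into odd, consume even
            std::copy(data.begin() + k + k_per_block,
                      data.begin() + k + 2 * k_per_block,
                      odd.begin());
            acc = std::accumulate(even.begin(), even.end(), acc);

            // odd iteration: prefetch the slice after that into even, consume odd
            if(k + 2 * k_per_block < K)
                std::copy(data.begin() + k + 2 * k_per_block,
                          data.begin() + k + 3 * k_per_block,
                          even.begin());
            acc = std::accumulate(odd.begin(), odd.end(), acc);

            k += 2 * k_per_block;
        } while(k < K);

        return acc; // equals the plain sum of data
    }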
@@ -517,26 +519,28 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         // LDS double buffer: tail
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k_m_grid_desc,
-                                                a_block_slice_copy_step,
-                                                a_k_m_global_move_slice_window_iterator_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k_n_grid_desc,
-                                                b_block_slice_copy_step,
-                                                b_k_n_global_move_slice_window_iterator_hack);
+            a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
+                                                a_block_slice_copy_step,
+                                                a_k_m0_m1_global_move_slice_window_iterator_hack);
+            b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
+                                                b_block_slice_copy_step,
+                                                b_k_n0_n1_global_move_slice_window_iterator_hack);

             __syncthreads();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(
+                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
+            b_blockwise_copy.RunRead(
+                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(
                 c_m10_n10_m11_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
+            a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
+            b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);

             __syncthreads();
@@ -593,10 +597,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                 CGlobalMemoryDataOperation,
                 1,
                 true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
-                      make_multi_index(m_block_work_idx,
+                      make_multi_index(m0_idx,
                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
-                                       n_block_work_idx,
+                                       n0_idx,
                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
                 .Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
......
@@ -74,7 +74,7 @@ __host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
 template <class X, class Y>
 __host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
 {
-    return (x + y - 1) / y;
+    return (x + y - Number<1>{}) / y;
 }

 template <class X, class Y>
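The only change here is replacing the literal 1 with Number<1>{}, which keeps the whole expression in the compile-time Number domain when x and y are themselves Numbers, rather than decaying the result to a runtime integer. A self-contained sketch with a stand-in Num type (Num is hypothetical, mimicking ck::Number; C++17 for is_same_v):

    #include <type_traits>

    template <int N>
    struct Num
    {
        static constexpr int value = N;
    };

    template <int A, int B> constexpr auto operator+(Num<A>, Num<B>) { return Num<A + B>{}; }
    template <int A, int B> constexpr auto operator-(Num<A>, Num<B>) { return Num<A - B>{}; }
    template <int A, int B> constexpr auto operator/(Num<A>, Num<B>) { return Num<A / B>{}; }

    template <class X, class Y>
    constexpr auto integer_divide_ceil(X x, Y y)
    {
        // (x + y - 1) would require a Num<A> + int overload and would lose
        // the compile-time type; Num<1>{} keeps the result a Num<...>.
        return (x + y - Num<1>{}) / y;
    }

    // ceil(10 / 4) == 3, and the result is still a compile-time constant type
    static_assert(std::is_same_v<decltype(integer_divide_ceil(Num<10>{}, Num<4>{})), Num<3>>);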
......
@@ -78,302 +78,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
     const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
 #endif

-#if 0
+#if 1
-    // cdata = 16, BlockSize = 64, 16x64x4
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 2;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 2;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
-#elif 0
-    // cdata = 32, BlockSize 64, 16x128x4
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize 64, 16x256x2
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 256;
-    constexpr index_t GemmKPerBlock = 2;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 1;
-    constexpr index_t GemmNLevel1Cluster = 16;
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize 64, 16x256x4
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 256;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 1;
-    constexpr index_t GemmNLevel1Cluster = 16;
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#elif 0
-    // cdata = 16, BlockSize = 64, 16x64x4
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 2;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 2;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
-#elif 0
-    // cdata = 32, BlockSize = 64, 16x128x4
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize = 128, 32x256x8
-    constexpr index_t BlockSize = 128;
-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 256;
-    constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 16;
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x2
-    constexpr index_t BlockSize = 256;
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 2;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x4
-    constexpr index_t BlockSize = 256;
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 1
     // cdata = 64, BlockSize = 256, 128x128x8
     // b thread copy 4x1
     constexpr index_t BlockSize = 256;

@@ -391,82 +96,19 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
     constexpr index_t GemmMLevel1Cluster = 8;
     constexpr index_t GemmNLevel1Cluster = 8;

-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 1
-    // cdata = 64, BlockSize = 256, 128x128x8
-    // b thread copy 2x2
-    constexpr index_t BlockSize = 256;
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 1
-    // cdata = 64, BlockSize = 256, 128x128x16
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 256;
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 16;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
+    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM0_GemmM1 = Sequence<4, 1, 1>;
+    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM0_GemmM1 = Sequence<2, 1, 128>;
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM1 = 1;
+    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN0_GemmN1 = Sequence<4, 1, 1>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN0_GemmN1 = Sequence<2, 1, 128>;
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN1 = 1;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN1 = 1;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN11 = 1;
 #endif
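For the 128x128x8 configuration that is kept, the numbers can be cross-checked with the usual blockwise-GEMM invariants (a sketch; these asserts are illustrative, not taken from this repository):

    constexpr int BlockSize          = 256;
    constexpr int GemmMPerBlock      = 128, GemmNPerBlock      = 128;
    constexpr int GemmMPerThread     = 4,   GemmNPerThread     = 4;
    constexpr int GemmMLevel0Cluster = 2,   GemmNLevel0Cluster = 2;
    constexpr int GemmMLevel1Cluster = 8,   GemmNLevel1Cluster = 8;

    // the thread cluster accounts for every thread in the block: 2*2*8*8 == 256
    static_assert(BlockSize == GemmMLevel0Cluster * GemmNLevel0Cluster *
                                   GemmMLevel1Cluster * GemmNLevel1Cluster,
                  "thread cluster must fill the block");

    // each thread's 4x4 micro-tile, repeated over the cluster, tiles 128x128
    static_assert(GemmMPerBlock % (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
                  "M tile must be covered");
    static_assert(GemmNPerBlock % (GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
                  "N tile must be covered");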
     const auto descs =

@@ -483,15 +125,23 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
     const auto out_gemmm_gemmn_grid_desc = descs[I2];

     // HACK: hacks that control index calculation when iterating over A, B, C matrix
-    constexpr auto wei_gemmk_gemmm_grid_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
+    constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{}));

-    constexpr auto in_gemmk_gemmn_grid_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
+    constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));

     constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
@@ -507,10 +157,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
                               Sequence<0, 0, 2, 0, 0>{},
                               Sequence<0, 0, 2, 0, 0>{}));

-    constexpr auto wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
+    constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0>{};

-    constexpr auto in_gemmk_gemmn_grid_move_slice_window_iterator_hacks =
-        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
+    constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};

     for(index_t i = 0; i < 5; ++i)
     {
@@ -533,31 +184,30 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
             GemmNLevel1Cluster,
             GemmMLevel0Cluster,
             GemmNLevel0Cluster,
-            GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
-            GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
-            Sequence<1, 0>,
-            Sequence<1, 0>,
-            0,
+            GemmABlockTransferThreadSliceLengths_GemmK_GemmM0_GemmM1,
+            GemmABlockTransferThreadClusterLengths_GemmK_GemmM0_GemmM1,
+            Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
+            Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder
+            0,                 // ABlockTransferSrcVectorDim
             GemmABlockTransferSrcScalarPerVector_GemmK,
-            GemmABlockTransferDstScalarPerVector_GemmM,
+            GemmABlockTransferDstScalarPerVector_GemmM1,
             false, // don't move back src coordinate after threadwise copy
-            GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
-            GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
-            Sequence<0, 1>,
-            Sequence<0, 1>,
-            1,
-            GemmBBlockTransferSrcScalarPerVector_GemmN,
-            GemmBBlockTransferDstScalarPerVector_GemmN,
-            false, // don't move back src coordinate after threadwise copy, which will be fused with
-                   // MoveSrcSliceWindow() to save addr computation
+            GemmBBlockTransferThreadSliceLengths_GemmK_GemmN0_GemmN1,
+            GemmBBlockTransferThreadClusterLengths_GemmK_GemmN0_GemmN1,
+            Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
+            Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder
+            2,                 // BBlockTransferSrcVectorDim
+            GemmBBlockTransferSrcScalarPerVector_GemmN1,
+            GemmBBlockTransferDstScalarPerVector_GemmN1,
+            false, // don't move back src coordinate after threadwise copy
             Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
             5,                          // CThreadTransferSrcDstVectorDim
-            GemmCThreadTransferDstScalarPerVector_GemmN1,
-            decltype(wei_gemmk_gemmm_grid_iterator_hacks),
-            decltype(in_gemmk_gemmn_grid_iterator_hacks),
+            GemmCThreadTransferDstScalarPerVector_GemmN11,
+            decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks),
+            decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks),
             decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
-            decltype(wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks),
-            decltype(in_gemmk_gemmn_grid_move_slice_window_iterator_hacks)>(
+            decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks),
+            decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>(
             static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                 wei_k_c_y_x_device_buf.GetDeviceBuffer()),
             static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
@@ -566,11 +216,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
             wei_gemmk_gemmm_grid_desc,
             in_gemmk_gemmn_grid_desc,
             out_gemmm_gemmn_grid_desc,
-            wei_gemmk_gemmm_grid_iterator_hacks,
-            in_gemmk_gemmn_grid_iterator_hacks,
+            wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks,
+            in_gemmk_gemmn0_gemmn1_grid_iterator_hacks,
             out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
-            wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks,
-            in_gemmk_gemmn_grid_move_slice_window_iterator_hacks,
+            wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks,
+            in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks,
             nrepeat);

     float perf = (float)calculate_convolution_flops(
......