Commit 49b926b6 authored by Chao Liu's avatar Chao Liu
Browse files

overhauling fwd-v4r4

parent b8382727
...@@ -26,21 +26,21 @@ template <index_t BlockSize, ...@@ -26,21 +26,21 @@ template <index_t BlockSize,
index_t M1N1ThreadClusterN10, index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11, index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11, index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M, typename ABlockTransferThreadSliceLengths_K_M0_M1,
typename ABlockTransferThreadClusterLengths_K_M, typename ABlockTransferThreadClusterLengths_K_M0_M1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M, index_t ABlockTransferDstScalarPerVector_M1,
bool AThreadTransferSrcResetCoordinateAfterRun, bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N, typename BBlockTransferThreadSliceLengths_K_N0_N1,
typename BBlockTransferThreadClusterLengths_K_N, typename BBlockTransferThreadClusterLengths_K_N0_N1,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N, index_t BBlockTransferDstScalarPerVector_N1,
bool BThreadTransferSrcResetCoordinateAfterRun, bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
...@@ -69,48 +69,48 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -69,48 +69,48 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
// GEMM // GEMM
using GridwiseGemm = using GridwiseGemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2<BlockSize, GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
AKMGridDesc, AKMGridDesc,
BKNGridDesc, BKNGridDesc,
CMNGridDesc, CMNGridDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
M1N1ThreadClusterM10, M1N1ThreadClusterM10,
M1N1ThreadClusterN10, M1N1ThreadClusterN10,
M1N1ThreadClusterM11, M1N1ThreadClusterM11,
M1N1ThreadClusterN11, M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M, ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M, ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M, ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N, BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N, BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N, BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowIteratorHacks>;
const auto M = a_k_m_grid_desc.GetLength(I1); const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1); const auto N = b_k_n_grid_desc.GetLength(I1);
...@@ -118,8 +118,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -118,8 +118,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc)) if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
{ {
throw std::runtime_error( throw std::runtime_error("wrong! GridwiseDynamicGemm_km_kn_mn_v1r2 has invalid setting");
"wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2 has invalid setting");
} }
const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
...@@ -154,8 +153,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -154,8 +153,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
kernel_dynamic_gemm_v1r2<GridwiseGemm, kernel_dynamic_gemm_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
...@@ -172,8 +169,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -172,8 +169,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
a_k_m0_m1_grid_desc, a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc, b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
...@@ -185,8 +180,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -185,8 +180,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
kernel_dynamic_gemm_v1r2<GridwiseGemm, kernel_dynamic_gemm_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
...@@ -203,8 +196,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -203,8 +196,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
a_k_m0_m1_grid_desc, a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc, b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
...@@ -216,8 +207,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -216,8 +207,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
kernel_dynamic_gemm_v1r2<GridwiseGemm, kernel_dynamic_gemm_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
...@@ -234,8 +223,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -234,8 +223,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
a_k_m0_m1_grid_desc, a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc, b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
...@@ -247,8 +234,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -247,8 +234,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
kernel_dynamic_gemm_v1r2<GridwiseGemm, kernel_dynamic_gemm_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
...@@ -265,8 +250,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -265,8 +250,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
a_k_m0_m1_grid_desc, a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc, b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
......
...@@ -74,7 +74,7 @@ __host__ __device__ constexpr auto integer_divide_floor(X x, Y y) ...@@ -74,7 +74,7 @@ __host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
template <class X, class Y> template <class X, class Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y) __host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
{ {
return (x + y - 1) / y; return (x + y - Number<1>{}) / y;
} }
template <class X, class Y> template <class X, class Y>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment