Commit 263c5e41 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent f63f1636
...@@ -66,8 +66,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -66,8 +66,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
// GEMM // GEMM
using GridwiseGemm = using GridwiseGemm =
...@@ -134,11 +132,11 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -134,11 +132,11 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const auto GridSize = (M / MPerBlock) * (N / NPerBlock); const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K);
float ave_time = 0; float ave_time = 0;
...@@ -156,7 +154,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -156,7 +154,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(GridSize), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
...@@ -182,7 +180,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -182,7 +180,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(GridSize), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
...@@ -208,7 +206,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -208,7 +206,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(GridSize), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
...@@ -234,7 +232,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -234,7 +232,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(GridSize), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
......
...@@ -31,8 +31,8 @@ template <index_t BlockSize, ...@@ -31,8 +31,8 @@ template <index_t BlockSize,
index_t DstScalarPerVector, index_t DstScalarPerVector,
index_t SrcScalarStrideInVector, index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector, index_t DstScalarStrideInVector,
index_t ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferSrcResetCoordinateAfterRun,
index_t ThreadTransferDstResetCoordinateAfterRun> bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4 struct BlockwiseDynamicTensorSliceTransfer_v4
{ {
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
......
...@@ -126,6 +126,27 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2 ...@@ -126,6 +126,27 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB); return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
} }
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K)
{
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc) MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
{ {
......
...@@ -482,23 +482,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -482,23 +482,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
const auto in_gemmk_gemmn_grid_desc = descs[I1]; const auto in_gemmk_gemmn_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2]; const auto out_gemmm_gemmn_grid_desc = descs[I2];
// hack to control index calculation when iterating over wei_gemmk_gemmm_grid tensor // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk_gemmm_grid_iterator_hacks = constexpr auto wei_gemmk_gemmm_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk_gemmn_grid tensor
constexpr auto in_gemmk_gemmn_grid_iterator_hacks = constexpr auto in_gemmk_gemmn_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
constexpr auto in_gemmk_gemmn_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
...@@ -513,6 +507,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -513,6 +507,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})); Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
constexpr auto in_gemmk_gemmn_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_dynamic_gemm_v1r2< float ave_time = driver_dynamic_gemm_v1r2<
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment