Commit 263c5e41 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent f63f1636
......@@ -66,8 +66,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
// GEMM
using GridwiseGemm =
......@@ -134,11 +132,11 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K);
float ave_time = 0;
......@@ -156,7 +154,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(grid_size),
dim3(BlockSize),
0,
0,
......@@ -182,7 +180,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(grid_size),
dim3(BlockSize),
0,
0,
......@@ -208,7 +206,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(grid_size),
dim3(BlockSize),
0,
0,
......@@ -234,7 +232,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(grid_size),
dim3(BlockSize),
0,
0,
......
......@@ -31,8 +31,8 @@ template <index_t BlockSize,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
index_t ThreadTransferSrcResetCoordinateAfterRun,
index_t ThreadTransferDstResetCoordinateAfterRun>
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
......
......@@ -126,6 +126,27 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K)
{
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
{
......
......@@ -482,23 +482,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
const auto in_gemmk_gemmn_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// hack to control index calculation when iterating over wei_gemmk_gemmm_grid tensor
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk_gemmm_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk_gemmn_grid tensor
constexpr auto in_gemmk_gemmn_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
constexpr auto in_gemmk_gemmn_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
......@@ -513,6 +507,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
constexpr auto in_gemmk_gemmn_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_v1r2<
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment