Commit 4774d863 authored by Chao Liu

refactor

parent 5dd45128
@@ -98,9 +98,26 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
using CM0M1N0N1GridDesc = decltype(c_m0_m1_n0_n1_grid_desc);
#if 0
const auto c_m0_m10_m
#endif
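// Hierarchical tiling of C (inferred from the parameter names): the grid is split
// into M0 x N0 block tiles of size M1 = MPerBlock and N1 = NPerBlock; each block
// tile is further split into M10 x N10 repeats of an M11 x N11 sub-tile, where
// M11/N11 is the span one pass of the thread cluster covers (M1PerThread or
// N1PerThread elements per thread times the cluster extents).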
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
constexpr auto N11 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
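// The m and n coordinates are unmerged into (M0, M10, M11) and (N0, N10, N11),
// so an element C(m, n) maps to (m0, m10, m11, n0, n10, n11) with
// m = (m0 * M10 + m10) * M11 + m11 and n = (n0 * N10 + n10) * N11 + n11
// (assuming the usual convention that the last unmerged dimension varies fastest).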
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
// out_gemm_block_cluster_desc
const auto c_block_cluster_desc =
@@ -119,6 +136,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
BKNGridDesc,
CM0M1N0N1GridDesc,
CBlockClusterDesc,
CM0M10M11N0N10N11GridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
@@ -160,7 +178,6 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;
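// The K-loop shape is baked in at compile time: the four branches below instantiate
// kernel_dynamic_gemm_v1r2 with every combination of the two trailing boolean
// template arguments, which presumably correspond to has_main_k_block_loop and
// has_double_tail_k_block_loop.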
if(has_main_k_block_loop && has_double_tail_k_block_loop)
@@ -173,6 +190,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
true,
true>;
@@ -188,7 +206,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc);
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
@@ -200,6 +219,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
true,
false>;
@@ -215,7 +235,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc);
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
@@ -227,6 +248,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
false,
true>;
@@ -242,7 +264,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc);
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc);
}
else
{
@@ -254,6 +277,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
false,
false>;
@@ -269,138 +293,11 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc);
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc);
}
return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m_grid_desc_device_buf(sizeof(AKMGridDesc));
DeviceMem b_k_n_grid_desc_device_buf(sizeof(BKNGridDesc));
DeviceMem c_m0_m1_n0_n1_grid_desc_device_buf(sizeof(CM0M1N0N1GridDesc));
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
a_k_m_grid_desc_device_buf.ToDevice(&a_k_m_grid_desc);
b_k_n_grid_desc_device_buf.ToDevice(&b_k_n_grid_desc);
c_m0_m1_n0_n1_grid_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_grid_desc);
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
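// In this path the descriptors are built on the host, copied into device buffers,
// and handed to the kernel as __CONSTANT__ void pointers, presumably to keep the
// kernel-argument footprint small when the descriptors are only known at run time.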
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
return ave_time;
#endif
}
} // namespace ck
@@ -93,7 +93,8 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(make_pass_through_transform(K),
make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
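// The NKHW output tensor is thus viewed as a GemmM x GemmN matrix with
// GemmM = K (pass-through of the filter-count dimension) and
// GemmN = N * Ho * Wo (the merged batch and output spatial dimensions).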
@@ -28,7 +28,7 @@
#endif
// launch bounds
#define CK_USE_LAUNCH_BOUNDS 1
#define CK_USE_LAUNCH_BOUNDS 0
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
@@ -499,6 +499,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
constexpr auto in_gemmk_gemmn_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
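// Index-calculation hints used when the GemmK slice window is moved over the
// input GemmK x GemmN grid descriptor; the meaning of the individual Sequence
// entries is internal to the iterator-hack mechanism.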
#if 0
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_grid
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
@@ -509,6 +510,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
#else
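// Replacement hacks for the new 6D output view: two groups of six Sequence
// entries, one per dimension of (M0, M10, M11, N0, N10, N11), presumably one
// group per iteration direction, mirroring the 4D version disabled above.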
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
#endif
for(index_t i = 0; i < 5; ++i)
{
@@ -553,7 +569,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(wei_gemmk_gemmm_grid_iterator_hacks),
decltype(in_gemmk_gemmn_grid_iterator_hacks),
#if 0
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks),
#else
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
#endif
decltype(wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks),
decltype(in_gemmk_gemmn_grid_move_slice_window_iterator_hacks)>(
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
@@ -566,7 +586,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
out_gemmm_gemmn_grid_desc,
wei_gemmk_gemmm_grid_iterator_hacks,
in_gemmk_gemmn_grid_iterator_hacks,
#if 0
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks,
#else
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
#endif
wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks,
in_gemmk_gemmn_grid_move_slice_window_iterator_hacks,
nrepeat);