Commit 4774d863 authored by Chao Liu

refactor

parent 5dd45128
@@ -98,9 +98,26 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
using CM0M1N0N1GridDesc = decltype(c_m0_m1_n0_n1_grid_desc);
#if 0
const auto c_m0_m10_m
#endif
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
constexpr auto N11 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
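For reference, the unmerge transforms above split the output's M and N dimensions into (block, sub-tile, intra-tile) triples. Below is a minimal sketch of the resulting index decomposition, assuming make_unmerge_transform uses the usual row-major split, i.e. m = (m0 * M10 + m10) * M11 + m11 and likewise for n; the helper is illustrative only, not CK API.

struct CIndex6D
{
    int m0, m10, m11, n0, n10, n11;
};

inline CIndex6D decompose_c_index(int m, int n, int M10, int M11, int N10, int N11)
{
    CIndex6D idx;
    // split m into (m0, m10, m11); m11 varies fastest
    idx.m11 = m % M11;
    idx.m10 = (m / M11) % M10;
    idx.m0  = m / (M11 * M10); // block index along M
    // split n into (n0, n10, n11) the same way
    idx.n11 = n % N11;
    idx.n10 = (n / N11) % N10;
    idx.n0  = n / (N11 * N10); // block index along N
    return idx;
}

For example, with MPerBlock = 128 and M11 = 64 (so M10 = 2), element m = 300 falls in block m0 = 2, sub-tile m10 = 0, at offset m11 = 44.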
// out_gemm_block_cluster_desc
const auto c_block_cluster_desc =
@@ -119,6 +136,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
BKNGridDesc,
CM0M1N0N1GridDesc,
CBlockClusterDesc,
CM0M10M11N0N10N11GridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
@@ -160,7 +178,6 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
@@ -173,6 +190,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
true,
true>;
@@ -188,7 +206,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc);
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
@@ -200,6 +219,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
true,
false>;
@@ -215,7 +235,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc);
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
@@ -227,6 +248,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
false,
true>;
@@ -242,7 +264,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc);
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc);
}
else
{
@@ -254,6 +277,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
false,
false>;
@@ -269,138 +293,11 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc);
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc);
}
return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m_grid_desc_device_buf(sizeof(AKMGridDesc));
DeviceMem b_k_n_grid_desc_device_buf(sizeof(BKNGridDesc));
DeviceMem c_m0_m1_n0_n1_grid_desc_device_buf(sizeof(CM0M1N0N1GridDesc));
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
a_k_m_grid_desc_device_buf.ToDevice(&a_k_m_grid_desc);
b_k_n_grid_desc_device_buf.ToDevice(&b_k_n_grid_desc);
c_m0_m1_n0_n1_grid_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_grid_desc);
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
return ave_time;
#endif
}
} // namespace ck
@@ -93,7 +93,8 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(make_pass_through_transform(K),
make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
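The transform above views the packed (N, K, Ho * Wo) output as a GemmM x GemmN matrix with GemmM = K and GemmN = N * Ho * Wo. A small sketch of the implied offset mapping, assuming make_merge_transform combines (N, Ho * Wo) in row-major order; the helper name is illustrative only, not CK API.

#include <cstdint>

// offset into the packed (N, K, Ho*Wo) output buffer for a given GEMM coordinate
inline std::int64_t out_gemm_to_nkhw_offset(std::int64_t gemm_m, // = k
                                            std::int64_t gemm_n, // = n * (Ho * Wo) + ho * Wo + wo
                                            std::int64_t K,
                                            std::int64_t HoWo)
{
    const std::int64_t n  = gemm_n / HoWo; // image index recovered from the merged dimension
    const std::int64_t hw = gemm_n % HoWo; // spatial position within the image
    const std::int64_t k  = gemm_m;        // pass-through dimension
    // packed layout strides: stride_N = K * HoWo, stride_K = HoWo, stride_HW = 1
    return (n * K + k) * HoWo + hw;
}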
@@ -12,15 +12,15 @@
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AKMGridDesc,
typename BKNGridDesc,
typename CM0M1N0N1GridDesc,
typename CM10M11N10N11GridDesc,
typename CBlockClusterDesc,
typename CM0M10M11N0N10N11GridDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
@@ -32,69 +32,21 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const AKMGridDesc a_k_m_grid_desc,
const BKNGridDesc b_k_n_grid_desc,
const CM0M1N0N1GridDesc c_m0_m1_n0_n1_grid_desc,
const CBlockClusterDesc c_block_cluster_desc)
const CM10M11N10N11GridDesc c_m10_n10_m11_n11_grid_desc,
const CBlockClusterDesc c_block_cluster_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc)
{
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_m10_n10_m11_n11_grid_desc,
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptors by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature
// point to the non-modifiable parameter address space, so it can enable the corresponding optimizations
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AKMGridDesc,
typename BKNGridDesc,
typename CM0M1N0N1GridDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k_m_grid_desc,
const void __CONSTANT__* p_b_k_n_grid_desc,
const void __CONSTANT__* p_c_m0_m1_n0_n1_grid_desc,
const void __CONSTANT__* p_c_block_cluster_desc)
{
// first cast the __CONSTANT__ void* to a plain void*,
// then cast the void* to the concrete Desc*;
// the tensor descriptor's copy constructor doesn't accept an address_space(4) pointer
const auto a_k_m_grid_desc =
*reinterpret_cast<const AKMGridDesc*>((const void*)p_a_k_m_grid_desc);
const auto b_k_n_grid_desc =
*reinterpret_cast<const BKNGridDesc*>((const void*)p_b_k_n_grid_desc);
const auto c_m0_m1_n0_n1_grid_desc =
*reinterpret_cast<const CM0M1N0N1GridDesc*>((const void*)p_c_m0_m1_n0_n1_grid_desc);
const auto c_block_cluster_desc =
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif
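The CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER path boils down to: copy each POD descriptor into a device buffer on the host, hand the kernel a raw pointer, and reinterpret it back to the concrete descriptor type before use. A minimal, self-contained sketch of that pattern in plain HIP follows; the names are hypothetical, and CK additionally tags the kernel parameter with its __CONSTANT__ address-space macro, which is omitted here.

#include <hip/hip_runtime.h>

struct TileDesc // stand-in for a CK tensor descriptor: a trivially copyable value type
{
    int lengths[2];
    int strides[2];
};

__global__ void use_desc(const void* p_desc, int* p_out)
{
    // cast the opaque pointer back to the concrete descriptor type and copy it by value
    const TileDesc desc = *reinterpret_cast<const TileDesc*>(p_desc);
    p_out[0] = desc.lengths[0] * desc.strides[0];
}

void launch_with_desc(const TileDesc& host_desc, int* p_out_dev)
{
    void* p_desc_dev = nullptr;
    hipMalloc(&p_desc_dev, sizeof(TileDesc));                                   // device-side copy of the descriptor
    hipMemcpy(p_desc_dev, &host_desc, sizeof(TileDesc), hipMemcpyHostToDevice); // ToDevice(...) in CK terms
    use_desc<<<1, 1>>>(p_desc_dev, p_out_dev);                                  // the kernel only sees a void*
    hipDeviceSynchronize();
    hipFree(p_desc_dev);
}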
template <index_t BlockSize,
typename FloatAB,
@@ -103,18 +55,19 @@ template <index_t BlockSize,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CM0M1N0N1GridDesc,
typename CM10M11N10N11GridDesc,
typename CBlockClusterDesc,
typename CM0M10M11N0N10N11GridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
index_t M1N1ThreadClusterM100,
index_t M1N1ThreadClusterN100,
index_t M1N1ThreadClusterM101,
index_t M1N1ThreadClusterN101,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
@@ -174,7 +127,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
}
__host__ __device__ static constexpr auto
MakeAKM0M1BlockClusterizedGridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
{
const auto K = a_k_m_grid_desc.GetLength(I0);
const auto M = a_k_m_grid_desc.GetLength(I1);
@@ -192,7 +145,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
}
__host__ __device__ static constexpr auto
MakeBKN0N1BlockClusterizedGridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
{
const auto K = b_k_n_grid_desc.GetLength(I0);
const auto N = b_k_n_grid_desc.GetLength(I1);
@@ -209,8 +162,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
return b_k_n0_n1_block_clusterized_grid_desc;
}
using AKM0M1GridDesc = decltype(MakeAKM0M1BlockClusterizedGridDescriptor(AKMGridDesc{}));
using BKN0N1GridDesc = decltype(MakeBKN0N1BlockClusterizedGridDescriptor(BKNGridDesc{}));
#if 0
__host__ __device__ static constexpr auto
MakeCM0M10M11N0N10N11GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
{
}
#endif
using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
@@ -218,8 +178,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
FloatC* __restrict__ p_c_grid,
const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CM0M1N0N1GridDesc& c_m0_m1_n0_n1_grid_desc,
const CM10M11N10N11GridDesc& c_m10_n10_m11_n11_grid_desc,
const CBlockClusterDesc& c_block_cluster_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
@@ -229,12 +190,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_k_n_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_m0_m1_n0_n1_grid_desc.GetElementSpaceSize());
p_c_grid, c_m10_n10_m11_n11_grid_desc.GetElementSpaceSize());
const auto K = a_k_m_grid_desc.GetLength(I0);
const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -332,17 +299,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
M1N1ThreadClusterM100,
M1N1ThreadClusterN100,
M1N1ThreadClusterM101,
M1N1ThreadClusterN101,
M1PerThread,
N1PerThread>{};
constexpr auto c_m0_m1_n0_n1_thread_tensor_lengths =
constexpr auto c_m10_n10_m11_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
constexpr auto c_m0_m1_n0_n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_m0_m1_n0_n1_thread_tensor_lengths));
constexpr auto c_m10_n10_m11_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_m10_n10_m11_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
@@ -356,12 +324,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
c_m10_n10_m11_n11_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_thread_tensor_lengths)>{}
.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
decltype(c_m10_n10_m11_n11_thread_desc),
decltype(c_m10_n10_m11_n11_thread_tensor_lengths)>{}
.Run(c_m10_n10_m11_n11_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
@@ -421,8 +392,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_m0_m1_n0_n1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
blockwise_gemm.Run(c_m10_n10_m11_n11_thread_desc,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
@@ -446,7 +419,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_m0_m1_n0_n1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
c_m10_n10_m11_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
@@ -474,7 +447,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_m0_m1_n0_n1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
c_m10_n10_m11_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
@@ -484,7 +457,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m0_m1_n0_n1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
c_m10_n10_m11_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
@@ -492,42 +465,95 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m0_m1_n0_n1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
c_m10_n10_m11_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
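The loop structure above is a classic ping-pong pipeline: while the blockwise GEMM consumes one LDS buffer, the next K slice is fetched into the other, and the tail finishes on one or two remaining slices depending on HasDoubleTailKBlockLoop. A schematic of that control flow in plain C++ with placeholder callables (illustrative only; on the GPU the loads and the GEMM are overlapped rather than called sequentially):

#include <functional>

void double_buffered_k_loop(int num_k_blocks,
                            const std::function<void(int /*k_block*/, int /*buf*/)>& load_to_lds,
                            const std::function<void(int /*buf*/)>& blockwise_gemm)
{
    load_to_lds(0, 0);          // prologue: fill the "even" buffer
    int k = 0;
    while(k + 2 < num_k_blocks) // main loop (HasMainKBlockLoop)
    {
        load_to_lds(k + 1, 1);  // fetch the next slice into the "odd" buffer
        blockwise_gemm(0);      // GEMM on current ("even") data
        load_to_lds(k + 2, 0);  // fetch the slice after that into "even"
        blockwise_gemm(1);      // GEMM on "odd" data
        k += 2;
    }
    if(num_k_blocks - k == 2)   // HasDoubleTailKBlockLoop
    {
        load_to_lds(k + 1, 1);  // last slice into "odd"
        blockwise_gemm(0);      // GEMM on 2nd-last data
        blockwise_gemm(1);      // GEMM on last data
    }
    else
    {
        blockwise_gemm(0);      // GEMM on last data
    }
}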
// output: register to global memory
{
constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM10 * M1N1ThreadClusterM11>{};
constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN10 * M1N1ThreadClusterN11>{};
#if 0
constexpr auto M11 = Number<M1PerThread * M1N1ThreadClusterM100 * M1N1ThreadClusterM101>{};
constexpr auto N11 = Number<N1PerThread * M1N1ThreadClusterN100 * M1N1ThreadClusterN101>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGridIteratorHacks{};
// hack to control index calculation when iterating over c_m10_n10_m11_n11_global tensor
constexpr auto c_m10_n10_m11_n11_global_tensor_iterator_hacks = CGridIteratorHacks{};
const auto c_thread_data_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginIndex(get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_grid_desc),
decltype(c_m0_m1_n0_n1_thread_tensor_lengths),
decltype(c_m10_n10_m11_n11_thread_desc),
decltype(c_m10_n10_m11_n11_grid_desc),
decltype(c_m10_n10_m11_n11_thread_tensor_lengths),
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_n0_n1_grid_desc,
make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
c_m10_n10_m11_n11_grid_desc,
make_multi_index(m_block_data_idx_on_global / M11 + c_thread_data_idx_on_block[I0],
c_thread_data_idx_on_block[I1],
n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
n_block_data_idx_on_global / N11 + c_thread_data_idx_on_block[I2],
c_thread_data_idx_on_block[I3])}
.Run(c_m0_m1_n0_n1_thread_desc,
.Run(c_m10_n10_m11_n11_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
c_m0_m1_n0_n1_grid_desc,
c_m10_n10_m11_n11_grid_desc,
c_grid_buf,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
c_m10_n10_m11_n11_global_tensor_iterator_hacks);
#else
constexpr index_t M11 = M1PerThread * M1N1ThreadClusterM100 * M1N1ThreadClusterM101;
constexpr index_t N11 = N1PerThread * M1N1ThreadClusterN100 * M1N1ThreadClusterN101;
constexpr index_t M10 = MPerBlock / M11;
constexpr index_t N10 = NPerBlock / N11;
constexpr index_t M111 = M1PerThread;
constexpr index_t N111 = N1PerThread;
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<1>{},
Number<c_m10_n10_m11_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_n10_m11_n11_thread_tensor_lengths[I1]>{},
Number<1>{},
Number<c_m10_n10_m11_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_n10_m11_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginIndex(get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
Sequence<1,
c_m10_n10_m11_n11_thread_tensor_lengths[I0],
c_m10_n10_m11_n11_thread_tensor_lengths[I1],
1,
c_m10_n10_m11_n11_thread_tensor_lengths[I2],
c_m10_n10_m11_n11_thread_tensor_lengths[I3]>,
Sequence<3, 4, 5, 0, 1, 2>, // TODO: CThreadTransferSrcDstAccessOrder
5, // TODO: CThreadTransferSrcDstVectorDim
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
make_multi_index(__builtin_amdgcn_readfirstlane(block_work_idx[I0]),
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
__builtin_amdgcn_readfirstlane(block_work_idx[I1]),
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf,
CGridIteratorHacks{});
#endif
}
}
@@ -537,8 +563,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
FloatC* __restrict__ p_c_grid,
const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CM0M1N0N1GridDesc& c_m0_m1_n0_n1_grid_desc,
const CM10M11N10N11GridDesc& c_m10_n10_m11_n11_grid_desc,
const CBlockClusterDesc& c_block_cluster_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
@@ -551,8 +578,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_m10_n10_m11_n11_grid_desc,
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -28,7 +28,7 @@
#endif
// launch bounds
#define CK_USE_LAUNCH_BOUNDS 1
#define CK_USE_LAUNCH_BOUNDS 0
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
@@ -499,6 +499,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
constexpr auto in_gemmk_gemmn_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
#if 0
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_grid
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
@@ -509,6 +510,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
#else
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
#endif
for(index_t i = 0; i < 5; ++i)
{
@@ -553,7 +569,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(wei_gemmk_gemmm_grid_iterator_hacks),
decltype(in_gemmk_gemmn_grid_iterator_hacks),
#if 0
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks),
#else
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
#endif
decltype(wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks),
decltype(in_gemmk_gemmn_grid_move_slice_window_iterator_hacks)>(
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
@@ -566,7 +586,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
out_gemmm_gemmn_grid_desc,
wei_gemmk_gemmm_grid_iterator_hacks,
in_gemmk_gemmn_grid_iterator_hacks,
#if 0
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks,
#else
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
#endif
wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks,
in_gemmk_gemmn_grid_move_slice_window_iterator_hacks,
nrepeat);