Commit 937ad6c4 authored by root

inline, tuned

parent 5b242405
......@@ -165,7 +165,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
#if 1
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v2<
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
BlockSize,
Float,
AccFloat,
......@@ -189,13 +189,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 2, 1, 0>,
Sequence<0, 2, 3, 1>,
3,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<3, 2, 1, 0>,
Sequence<0, 2, 3, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(a_k_m_global_iterator_hacks),
......@@ -224,34 +224,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
for(index_t j = 0; j < nrepeat; ++j)
{
#if 0
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_n_ho_wo_global_desc),
const Float*,
decltype(out_gemmm_n_ho_wo_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_n_ho_wo_global_desc,
p_in_global,
out_gemmm_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
#else
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
......@@ -360,7 +332,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
#endif
}
timer.End();
......
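The has_main_k_block_loop / has_double_tail_k_block_loop flags used in the dispatch above follow from the double-buffered loop structure further down in this diff (the main do/while consumes 2 * EPerBlock per trip and the tail handles the last one or two E-iterations). A hedged sketch of how the flags relate to the E-loop count; the exact formulas are assumptions consistent with that structure, not taken from this diff:

constexpr bool has_double_tail_k_block_loop(int E, int EPerBlock)
{
    // an even number of E-iterations leaves a 2-iteration tail
    return (E / EPerBlock) % 2 == 0;
}

constexpr bool has_main_k_block_loop(int E, int EPerBlock)
{
    // the main loop runs only if work remains after the preload and the tail
    return (E / EPerBlock) > (has_double_tail_k_block_loop(E, EPerBlock) ? 2 : 1);
}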
......@@ -11,485 +11,6 @@
namespace ck {
template <index_t BlockSize,
typename Float,
typename AccFloat,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
index_t KPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t EPerBlock,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v2
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_e_k_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size) * sizeof(Float);
}
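A minimal host-side sketch of the LDS size computation above, with tile sizes assumed for illustration (the kernel does the same arithmetic at compile time; the factor of 2 is the double buffer used later in Run()):

#include <cstdio>
#include <numeric> // std::lcm

int main()
{
    // illustrative values, not taken from this diff
    const int EPerBlock = 4, KPerBlock = 16, DstScalarPerVector_M = 1;
    const int max_lds_align = std::lcm(DstScalarPerVector_M, KPerBlock);
    // E x K tile, element count rounded up to the LDS alignment
    const int tile_elements = EPerBlock * KPerBlock;
    const int a_block_space_size =
        ((tile_elements + max_lds_align - 1) / max_lds_align) * max_lds_align;
    // two buffers (ping-pong) of Float
    std::printf("LDS bytes: %zu\n", 2 * a_block_space_size * sizeof(float));
    return 0;
}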
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const Float* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto E = a_e_k_global_desc.GetLength(I0);
const auto K = a_e_k_global_desc.GetLength(I1);
const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
// divide block work by [M, N]
#if 0
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num;
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num;
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#else
// Hack: this forces the result into an SGPR
const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num);
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id =
__builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num);
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#endif
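The #else branch above forces the per-block indices into scalar registers with __builtin_amdgcn_readfirstlane; the index arithmetic itself is the usual row-major decomposition of the 1D block id into (k, ho, wo) tile coordinates. A standalone sketch of that decomposition, with problem sizes assumed for illustration:

#include <cassert>

int main()
{
    const int K = 64, Ho = 32, Wo = 32; // assumed sizes
    const int KPerBlock = 16, HoPerBlock = 16, WoPerBlock = 16;
    const int k_block_work_num   = K / KPerBlock;
    const int ho_block_work_num  = Ho / HoPerBlock;
    const int wo_block_work_num  = Wo / WoPerBlock;
    const int hwo_block_work_num = ho_block_work_num * wo_block_work_num;
    for(int block_id = 0; block_id < k_block_work_num * hwo_block_work_num; ++block_id)
    {
        const int k_id   = block_id / hwo_block_work_num;
        const int hwo_id = block_id - k_id * hwo_block_work_num;
        const int ho_id  = hwo_id / wo_block_work_num;
        const int wo_id  = hwo_id - ho_id * wo_block_work_num;
        assert(k_id < k_block_work_num && ho_id < ho_block_work_num &&
               wo_id < wo_block_work_num);
    }
    return 0;
}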
// lds max alignment
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_e_n_ho_wo_block_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k_n_ho_wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread, // KPerThreadSubC
HoPerThread, // HoPerThreadSubC
WoPerThread, // WoPerThreadSubC
EPerThread, // EPerThreadLoop
1, // ThreadGemmADataPerRead_K
1 // ThreadGemmBDataPerRead_W
>{};
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const auto k_thread_id = c_thread_mtx_index.k;
const auto ho_thread_id = c_thread_mtx_index.h;
const auto wo_thread_id = c_thread_mtx_index.w;
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
const index_t ho_thread_data_on_global =
ho_block_data_on_global + ho_thread_id * HoPerThread;
const index_t wo_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread;
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<EPerBlock, KPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
Float,
Float,
decltype(a_e_k_global_desc),
decltype(a_e_k_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AddressSpace::Global,
AddressSpace::Lds,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_e_k_global_desc,
make_multi_index(0, k_block_data_on_global),
a_e_k_block_desc,
make_multi_index(0, 0));
constexpr auto b_e_n_ho_wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
Float,
Float,
decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
AddressSpace::Global,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
true>(b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_e_k_block_desc.GetElementSpaceSize(), max_lds_align);
Float* p_a_block_double = p_shared_block;
// register allocation for output
AccFloat p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
// zero out threadwise output
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
constexpr auto a_block_slice_copy_step = make_multi_index(EPerBlock, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when moving the slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize();
Float p_b_thread[b_thread_space_size * 2];
Float* p_b_thread_double = p_b_thread;
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_double,
b_e_n_ho_wo_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_e_k_block_desc, p_a_block_double);
}
#if 1
if constexpr(HasMainKBlockLoop)
{
Float* p_a_block_even = p_a_block_double;
Float* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_thread_even = p_b_thread_double;
Float* p_b_thread_odd = p_b_thread_double + b_thread_space_size;
index_t b_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_e_k_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_odd,
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_even, p_b_thread_even, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_e_k_block_desc, p_a_block_odd);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_e_k_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_even,
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_thread_odd, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_e_k_block_desc, p_a_block_even);
b_block_data_begin += 2 * EPerBlock;
} while(b_block_data_begin < E - 2 * EPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
{
a_blockwise_copy.MoveSrcSliceWindow(a_e_k_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_double + b_thread_space_size,
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_thread_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_e_k_block_desc, p_a_block_double + a_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_thread_double + b_thread_space_size,
p_c_thread);
}
else // if 1 iteration is left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_thread_double, p_c_thread);
}
#endif
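A compact sketch of the ping-pong schedule implemented above: compute on the resident buffer while the next tile is being fetched into the other one. The callable parameters are stand-ins (assumptions) for the RunRead/RunWrite/Run calls and __syncthreads():

template <typename LoadNext, typename GemmOn, typename Sync>
void double_buffer_loop(int E, int EPerBlock, LoadNext load_next, GemmOn gemm_on, Sync sync)
{
    int buf = 0;            // 0 = even buffer, 1 = odd buffer
    load_next(buf);         // preload the first tile
    for(int e = EPerBlock; e < E; e += EPerBlock)
    {
        sync();
        load_next(buf ^ 1); // fetch the next tile into the other buffer
        gemm_on(buf);       // GEMM on the tile that is already resident
        buf ^= 1;
    }
    sync();
    gemm_on(buf);           // tail: GEMM on the last tile
}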
#if 1
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;
ThreadwiseDynamicTensorSliceTransfer_v1r3<
AccFloat,
Float,
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
true>(
c_k_n_ho_wo_global_desc,
make_multi_index(
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
.Run(c_k_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_c_thread,
c_k_n_ho_wo_global_desc,
p_c_global,
c_k_n_ho_wo_global_tensor_iterator_hacks);
}
#endif
}
// pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const Float* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
Run(a_e_k_global_desc,
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// pass tensor descriptors by their pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
const Float* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_e_k_global_desc = *p_a_e_k_global_desc;
const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc;
const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc;
Run(a_e_k_global_desc,
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_e_k_global_desc,
const Float* __restrict__ p_a_global,
const void* p_b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global,
const void* p_c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_e_k_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
const auto b_e_n_ho_wo_global_desc =
*reinterpret_cast<const BGlobalDesc*>(p_b_e_n_ho_wo_global_desc);
const auto c_k_n_ho_wo_global_desc =
*reinterpret_cast<const CGlobalDesc*>(p_c_k_n_ho_wo_global_desc);
Run(a_e_k_global_desc,
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
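The last Run() overload above recovers typed tensor descriptors from void pointers before forwarding. A minimal standalone sketch of that recovery pattern; Desc and its field are stand-ins (assumptions), only the dereference-and-forward shape comes from the code above:

#include <cassert>

struct Desc { int length; }; // stand-in for a tensor descriptor type

int use_descriptor(const void* p_desc)
{
    // recover the concrete descriptor type, then work with a local copy
    const auto desc = *reinterpret_cast<const Desc*>(p_desc);
    return desc.length;
}

int main()
{
    Desc d{42};
    assert(use_descriptor(&d) == 42);
    return 0;
}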
template <index_t BlockSize,
typename Float,
typename AccFloat,
......@@ -572,7 +93,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
// divide block work by [M, N]
#if 1
#if 0
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
......@@ -625,13 +146,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread, // KPerThreadSubC
HoPerThread, // HoPerThreadSubC
WoPerThread, // WoPerThreadSubC
EPerThread, // EPerThreadLoop
1, // ThreadGemmADataPerRead_K
1 // ThreadGemmBDataPerRead_W
>{};
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M>{};
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
......@@ -687,9 +207,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
Sequence<3, 2, 0, 1>, // BBlockTransferSrcAccessOrder,
3, // BBlockTransferSrcVectorDim,
1, // BBlockTransferSrcScalarPerVector,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
AddressSpace::Global,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
......
......@@ -62,16 +62,68 @@ struct ThreadwiseGemm_km_kn_mn_v3
static_for<0, E, 1>{}([&](auto e) {
static_for<0, K, 1>{}([&](auto k) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(e, k));
if constexpr(H == 2 && W == 2)
{
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0));
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 1));
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0));
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 1));
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0));
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 1));
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0));
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 1));
amd_assembly_outer_product_1x4(p_a[a_offset],
p_b[b_offset_0],
p_b[b_offset_1],
p_b[b_offset_2],
p_b[b_offset_3],
p_c[c_offset_0],
p_c[c_offset_1],
p_c[c_offset_2],
p_c[c_offset_3]);
}
else if constexpr(H == 4 && W == 1)
{
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0));
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0));
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 2, 0));
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 3, 0));
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0));
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0));
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 2, 0));
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 3, 0));
amd_assembly_outer_product_1x4(p_a[a_offset],
p_b[b_offset_0],
p_b[b_offset_1],
p_b[b_offset_2],
p_b[b_offset_3],
p_c[c_offset_0],
p_c[c_offset_1],
p_c[c_offset_2],
p_c[c_offset_3]);
}
else
{
static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(e, k));
constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(e, 0, h, w));
constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(k, 0, h, w));
constexpr auto b_offset =
BDesc{}.CalculateOffset(make_tuple(e, 0, h, w));
constexpr auto c_offset =
CDesc{}.CalculateOffset(make_tuple(k, 0, h, w));
p_c[c_offset] +=
inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
p_c[c_offset] += inner_product_with_conversion<FloatC>{}(p_a[a_offset],
p_b[b_offset]);
});
});
}
});
});
}
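The two specialized branches above hand a 1x4 outer-product update to amd_assembly_outer_product_1x4; the generic branch below them is the scalar equivalent. A plain-C++ reference of the 1x4 step (an illustrative stand-in, not the assembly routine itself):

// c[i] += a * b[i] for i = 0..3; the assembly version maps each update
// to a fused multiply-accumulate
inline void outer_product_1x4_reference(float a,
                                        float b0, float b1, float b2, float b3,
                                        float& c0, float& c1, float& c2, float& c3)
{
    c0 += a * b0;
    c1 += a * b1;
    c2 += a * b2;
    c3 += a * b3;
}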
......
......@@ -85,7 +85,7 @@
// experimental implementation
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
......
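The flag flipped above selects the buffer-load out-of-bounds handling. A hedged model of the idea, assuming it works by redirecting invalid elements to an offset beyond the buffer's declared range so the hardware load returns 0 instead of taking a branch (the model below is illustrative C++, not the actual intrinsic path):

#include <cassert>
#include <vector>

// model of buffer_load: offsets beyond the declared range read back as 0
float buffer_load_model(const std::vector<float>& buf, long offset)
{
    return (offset >= 0 && offset < static_cast<long>(buf.size())) ? buf[offset] : 0.0f;
}

float load_with_offset_trick(const std::vector<float>& buf, long offset, bool is_valid)
{
    const long oob_offset = static_cast<long>(buf.size()); // any out-of-range offset
    return buffer_load_model(buf, is_valid ? offset : oob_offset);
}

int main()
{
    const std::vector<float> buf{1.0f, 2.0f, 3.0f};
    assert(load_with_offset_trick(buf, 1, true) == 2.0f);
    assert(load_with_offset_trick(buf, 1, false) == 0.0f);
    return 0;
}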
......@@ -68,19 +68,19 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
#endif
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 256;
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 16;
constexpr index_t WoPerBlock = 16;
constexpr index_t EPerBlock = 4;
constexpr index_t EPerBlock = 2;
constexpr index_t KPerThread = 4;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 1;
constexpr index_t EPerThread = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<9, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
......@@ -89,7 +89,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
......
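The retuned per-block sizes above determine how many workgroups the driver launches, assuming (as in the block-work decomposition earlier in this diff) that each workgroup covers one KPerBlock x HoPerBlock x WoPerBlock output tile. A sketch with problem sizes assumed for illustration:

constexpr int grid_size(int K, int Ho, int Wo,
                        int KPerBlock, int HoPerBlock, int WoPerBlock)
{
    return (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
}

// e.g. K = 64, Ho = Wo = 32 with the tuned 16x16x16 tiles -> 4 * 2 * 2 = 16 workgroups
static_assert(grid_size(64, 32, 32, 16, 16, 16) == 16, "illustrative check");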