Commit 95a5af02 authored by Jing Zhang's avatar Jing Zhang
Browse files

ksub

parent 43fc4ce7
......@@ -132,22 +132,26 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
constexpr auto KPerThreadSubC = 4;
static_assert(KPerThread % KPerThreadSubC == 0, "");
// thread A, B for GEMM
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThread>{}));
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
constexpr auto a_thread_copy = ThreadwiseSliceCopy_a<BlockMatrixA,
decltype(a_thread_mtx),
EPerThreadLoop,
KPerThread,
KPerThreadSubC,
ThreadGemmADataPerRead_K>{};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
......@@ -157,14 +161,20 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
#pragma unroll
for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
{
a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(e_begin, 0)) +
mMyThreadOffsetA,
p_a_thread);
threadwise_gemm.Run(p_a_thread,
p_b_thread +
b_thread_mtx.CalculateOffset(make_tuple(e_begin, 0, 0, 0)),
p_c_thread);
#pragma unroll
for(index_t k_begin = 0; k_begin < KPerThread; k_begin += KPerThreadSubC)
{
a_thread_copy.Run(p_a_block +
a_block_mtx.CalculateOffset(make_tuple(e_begin, k_begin)) +
mMyThreadOffsetA,
p_a_thread);
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(e_begin, 0, 0, 0)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(k_begin, 0, 0, 0)));
}
}
}
......
......@@ -11,7 +11,7 @@
#define CK_DEVICE_BACKEND_AMD 1
// GPU ID
#if 1
#if 0
#define CK_AMD_GPU_GFX906 1
#elif 0
#define CK_AMD_GPU_GFX908 1
......@@ -85,7 +85,7 @@
// experimental implementation
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
......
......@@ -73,12 +73,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 16;
constexpr index_t WoPerBlock = 16;
constexpr index_t EPerBlock = 2;
constexpr index_t EPerBlock = 4;
constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 2;
constexpr index_t EPerThread = 4;
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<4, 16>;
......@@ -88,7 +88,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
constexpr index_t CThreadTransferDstScalarPerVector_W = 1;
constexpr index_t CThreadTransferDstScalarPerVector_W = 2;
constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment