Commit 95a5af02 authored by Jing Zhang

ksub

parent 43fc4ce7
@@ -132,22 +132,26 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
 
+        constexpr auto KPerThreadSubC = 4;
+
+        static_assert(KPerThread % KPerThreadSubC == 0, "");
+
         // thread A, B for GEMM
         constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<EPerThreadLoop>{}, Number<KPerThread>{}));
+            make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
 
         constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
             Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
 
         constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-            Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
+            Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
 
         FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
 
         constexpr auto a_thread_copy = ThreadwiseSliceCopy_a<BlockMatrixA,
                                                              decltype(a_thread_mtx),
                                                              EPerThreadLoop,
-                                                             KPerThread,
+                                                             KPerThreadSubC,
                                                              ThreadGemmADataPerRead_K>{};
 
         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
@@ -157,14 +161,20 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
 #pragma unroll
         for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
         {
-            a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(e_begin, 0)) +
-                                  mMyThreadOffsetA,
-                              p_a_thread);
-
-            threadwise_gemm.Run(p_a_thread,
-                                p_b_thread +
-                                    b_thread_mtx.CalculateOffset(make_tuple(e_begin, 0, 0, 0)),
-                                p_c_thread);
+#pragma unroll
+            for(index_t k_begin = 0; k_begin < KPerThread; k_begin += KPerThreadSubC)
+            {
+                a_thread_copy.Run(p_a_block +
+                                      a_block_mtx.CalculateOffset(make_tuple(e_begin, k_begin)) +
+                                      mMyThreadOffsetA,
+                                  p_a_thread);
+
+                threadwise_gemm.Run(
+                    p_a_thread,
+                    p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(e_begin, 0, 0, 0)),
+                    p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(k_begin, 0, 0, 0)));
+            }
         }
     }
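The blockwise change above restructures the per-thread GEMM: instead of consuming the whole KPerThread tile of A in one step per e iteration, the commit introduces KPerThreadSubC = 4 and walks the K tile in sub-tiles, copying a smaller A fragment each time and accumulating into the matching K offset of the C tile (hence the new c_thread_mtx offset computed from k_begin). Below is a minimal, self-contained sketch of that loop structure; the sizes, array layouts, and helper names are illustrative assumptions, not the library's actual descriptors or copy classes.

// Sketch only: same e_begin / k_begin double loop as the diff, on plain arrays.
#include <array>
#include <cstdio>

constexpr int EPerBlock      = 8;  // reduction length visible to the thread (assumed)
constexpr int KPerThread     = 16; // full K tile owned by the thread
constexpr int KPerThreadSubC = 4;  // K sub-tile copied/accumulated per step
constexpr int EPerThreadLoop = 1;  // E chunk consumed per inner GEMM
constexpr int NPerThread     = 4;  // stands in for HPerThread * WPerThread

static_assert(KPerThread % KPerThreadSubC == 0, "");

int main()
{
    std::array<float, EPerBlock * KPerThread> a{};  // A[e][k]
    std::array<float, EPerBlock * NPerThread> b{};  // B[e][n]
    std::array<float, KPerThread * NPerThread> c{}; // C[k][n]

    for(int i = 0; i < EPerBlock * KPerThread; ++i) a[i] = 0.01f * i;
    for(int i = 0; i < EPerBlock * NPerThread; ++i) b[i] = 0.02f * i;

    for(int e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
    {
        // New inner loop from the diff: walk the K tile in KPerThreadSubC-wide pieces.
        for(int k_begin = 0; k_begin < KPerThread; k_begin += KPerThreadSubC)
        {
            // Plays the role of a_thread_copy.Run: copy a small A fragment
            // starting at the (e_begin, k_begin) origin.
            float a_frag[EPerThreadLoop * KPerThreadSubC];
            for(int e = 0; e < EPerThreadLoop; ++e)
                for(int k = 0; k < KPerThreadSubC; ++k)
                    a_frag[e * KPerThreadSubC + k] =
                        a[(e_begin + e) * KPerThread + (k_begin + k)];

            // Plays the role of threadwise_gemm.Run: accumulate into the
            // C sub-block selected by k_begin.
            for(int e = 0; e < EPerThreadLoop; ++e)
                for(int k = 0; k < KPerThreadSubC; ++k)
                    for(int n = 0; n < NPerThread; ++n)
                        c[(k_begin + k) * NPerThread + n] +=
                            a_frag[e * KPerThreadSubC + k] *
                            b[(e_begin + e) * NPerThread + n];
        }
    }

    std::printf("c[0] = %f\n", c[0]);
    return 0;
}

The point of the sub-tiling is that only EPerThreadLoop x KPerThreadSubC elements of A need to sit in registers at any time instead of the full EPerThreadLoop x KPerThread slice.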
@@ -11,7 +11,7 @@
 #define CK_DEVICE_BACKEND_AMD 1
 
 // GPU ID
-#if 1
+#if 0
 #define CK_AMD_GPU_GFX906 1
 #elif 0
 #define CK_AMD_GPU_GFX908 1
@@ -85,7 +85,7 @@
 // experimental implementation
 #ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
-#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
+#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
 #endif
 
 #ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
@@ -73,12 +73,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
     constexpr index_t KPerBlock  = 16;
     constexpr index_t HoPerBlock = 16;
     constexpr index_t WoPerBlock = 16;
-    constexpr index_t EPerBlock  = 2;
+    constexpr index_t EPerBlock  = 4;
 
     constexpr index_t KPerThread  = 16;
     constexpr index_t HoPerThread = 2;
     constexpr index_t WoPerThread = 2;
-    constexpr index_t EPerThread  = 2;
+    constexpr index_t EPerThread  = 4;
 
     using ABlockTransferThreadSliceLengths_E_K   = Sequence<9, 1>;
     using ABlockTransferThreadClusterLengths_E_K = Sequence<4, 16>;
@@ -88,7 +88,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
     constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
 
-    constexpr index_t CThreadTransferDstScalarPerVector_W = 1;
+    constexpr index_t CThreadTransferDstScalarPerVector_W = 2;
 
     constexpr auto conv_driver =
         DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
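The host-side change retunes the v5r1 forward convolution driver: EPerBlock and EPerThread go from 2 to 4, and CThreadTransferDstScalarPerVector_W from 1 to 2 (wider vectorized stores of C along W). A small sketch of the kind of divisibility relations such tile parameters are usually expected to satisfy is below; these checks are assumptions for illustration, the authoritative constraints are enforced inside the driver and the gridwise/blockwise GEMM.

// Sketch only: consistency checks on the tile sizes chosen above (assumed relations).
#include <cstdio>

int main()
{
    constexpr int KPerBlock  = 16, HoPerBlock = 16, WoPerBlock = 16, EPerBlock = 4;
    constexpr int KPerThread = 16, HoPerThread = 2, WoPerThread = 2, EPerThread = 4;

    // Assumed: each block tile decomposes into whole per-thread tiles,
    // and the per-thread E chunk evenly covers the per-block E extent.
    static_assert(KPerBlock % KPerThread == 0, "");
    static_assert(HoPerBlock % HoPerThread == 0, "");
    static_assert(WoPerBlock % WoPerThread == 0, "");
    static_assert(EPerBlock % EPerThread == 0, "");

    std::printf("C thread tiles per block: %d x %d x %d\n",
                KPerBlock / KPerThread,
                HoPerBlock / HoPerThread,
                WoPerBlock / WoPerThread);
    return 0;
}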