Commit 61faf02b authored by Chao Liu
Browse files

adding implicit GEMM v4r2

parent 1480375f
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
namespace ck { namespace ck {
// define B = merge(N0, Ho, Wo)
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
class Float, class Float,
...@@ -182,6 +181,12 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer ...@@ -182,6 +181,12 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0}, InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0},
{0, 0, 0, 0, 0, 0, 0, 0}); {0, 0, 0, 0, 0, 0, 0, 0});
#if 1
{
printf("id (%d %d), in offset: %d %d\n", get_block_1d_id(), get_thread_local_1d_id(), blockwise_in_copy.mThreadSrcOffset, blockwise_in_copy.mThreadDstOffset);
}
#endif
// weight tensor // weight tensor
// tensor descriptor in device memory, src of blockwise copy // tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc = constexpr auto wei_e_k_global_desc =
......
...@@ -53,15 +53,15 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc, ...@@ -53,15 +53,15 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0 #if 1
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
constexpr index_t N0 = 1; constexpr index_t N0 = 1;
constexpr index_t Ho0 = 1; constexpr index_t Ho0 = 2;
constexpr index_t Wo0 = 2; constexpr index_t Wo0 = 1;
constexpr index_t N2 = 1; constexpr index_t N2 = 4;
constexpr index_t Ho2 = 1; constexpr index_t Ho2 = 1;
constexpr index_t Wo2 = 4; constexpr index_t Wo2 = 1;
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -79,8 +79,8 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc, ...@@ -79,8 +79,8 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<1, 1, 1, 1, 1, 1, 1, 4>; using InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<1, 1, 1, 1, 1, 4, 1, 1>;
using InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<8, 1, 1, 2, 16, 1, 1, 1>; using InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<8, 1, 2, 1, 16, 1, 1, 1>;
using InBlockCopyThreadClusterArrangeOrder = using InBlockCopyThreadClusterArrangeOrder =
Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N0, N2, Ho0, Ho2, Wo0, B, Wo2] Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N0, N2, Ho0, Ho2, Wo0, B, Wo2]
using InBlockCopySrcAccessOrder = using InBlockCopySrcAccessOrder =
...@@ -88,7 +88,7 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc, ...@@ -88,7 +88,7 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc,
using InBlockCopyDstAccessOrder = using InBlockCopyDstAccessOrder =
Sequence<0, 1, 2, 3, 4, 5, 6, 7>; // [E, N0, Ho0, Wo0, B, N2, Ho2, Wo2] Sequence<0, 1, 2, 3, 4, 5, 6, 7>; // [E, N0, Ho0, Wo0, B, N2, Ho2, Wo2]
constexpr index_t InBlockCopyDataPerAccess_W2 = 4; constexpr index_t InBlockCopyDataPerAccess_W2 = 1;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment