Commit 6c2a3a95 authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/dynamic_tensor_descriptor' into dynamic_tensor_descriptor_v5r1

parents f744524e 6bf99b9e
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
template <typename GridwiseOp, typename... Xs> template <typename GridwiseOp, typename... Xs>
__global__ void __global__ void
#if 1 #if 0
__launch_bounds__(256, 2) __launch_bounds__(256, 2)
#endif #endif
run_gridwise_operation(Xs... xs) run_gridwise_operation(Xs... xs)
{ {
GridwiseOp{}.Run(xs...); GridwiseOp{}.Run(xs...);
} }
......
...@@ -91,7 +91,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -91,7 +91,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
const auto W = b_cyx_n_h_w_global_desc.GetLength(I3); const auto W = b_cyx_n_h_w_global_desc.GetLength(I3);
// divide block work by [M, N] // divide block work by [M, N]
#if 1 #if 0
const auto k_block_work_num = K / Number<KPerBlock>{}; const auto k_block_work_num = K / Number<KPerBlock>{};
const auto h_block_work_num = H / Number<HPerBlock>{}; const auto h_block_work_num = H / Number<HPerBlock>{};
const auto w_block_work_num = W / Number<WPerBlock>{}; const auto w_block_work_num = W / Number<WPerBlock>{};
...@@ -646,32 +646,33 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -646,32 +646,33 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id * WPerThread; const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id * WPerThread;
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4< auto a_blockwise_copy =
BlockSize, BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
Sequence<CYX, K>, Sequence<CYX, K>,
ABlockTransferThreadSliceLengths_K_M, ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M, ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
Float, Float,
Float, Float,
decltype(a_cyx_k_global_desc), decltype(a_cyx_k_global_desc),
decltype(a_cyx_k_desc), decltype(a_cyx_k_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<0, 1>, Sequence<0, 1>,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
1, 1,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M, ABlockTransferDstScalarPerVector_M,
AddressSpace::Global, AddressSpace::Global,
AddressSpace::Lds, AddressSpace::Lds,
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>(a_cyx_k_global_desc, true>(
make_multi_index(0, k_block_data_on_global), a_cyx_k_global_desc,
a_cyx_k_desc, make_multi_index(0, k_block_data_on_global),
make_multi_index(0, 0)); a_cyx_k_desc,
make_multi_index(0, 0));
constexpr auto b_cyx_n_h_w_thread_desc = constexpr auto b_cyx_n_h_w_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
......
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
#endif #endif
#ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM #ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0 #define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1
#endif #endif
#ifndef CK_USE_AMD_V_FMAC_F32 #ifndef CK_USE_AMD_V_FMAC_F32
...@@ -74,7 +74,7 @@ ...@@ -74,7 +74,7 @@
// experimental implementation // experimental implementation
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif #endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
......
...@@ -68,7 +68,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -68,7 +68,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif #endif
#if 1 #if 0
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
...@@ -101,6 +101,72 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -101,6 +101,72 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
#elif 0
// cdata = 32, BlockSize 64, 16x128x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 1
// cdata = 64, BlockSize 64, 16x256x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0 #elif 0
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
// GemmBBlockCopySrcDataPerRead_GemmN = 4 // GemmBBlockCopySrcDataPerRead_GemmN = 4
......
...@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
#if 1 #if 0
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
...@@ -211,7 +211,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -211,7 +211,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#elif 0 #elif 1
// cdata = 64, BlockSize = 64, 16x256x4 // cdata = 64, BlockSize = 64, 16x256x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -49,8 +48,8 @@ int main(int argc, char* argv[]) ...@@ -49,8 +48,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
...@@ -65,6 +64,20 @@ int main(int argc, char* argv[]) ...@@ -65,6 +64,20 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 4;
constexpr index_t HI = 1080;
constexpr index_t WI = 1920;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1 #elif 1
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 4;
...@@ -724,6 +737,23 @@ int main(int argc, char* argv[]) ...@@ -724,6 +737,23 @@ int main(int argc, char* argv[])
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size,
acc_data_t,
out_data_t>
(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(in_nchw_desc, device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment