Unverified commit 52c3fe05 authored by Chao Liu, committed by GitHub

Refactor for MIOpen integration (#4)

Refactor so that multi-index transformation and padding support can be brought into MIOpen.
parent 9aaeacc8
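Editor's note (not part of the commit): the "multi-index transformation and padding" mentioned in the commit message refers to viewing a padded NCHW input as an implicit-GEMM matrix whose row index is E = (c, y, x) and whose column index is B = (n, ho, wo). The stand-alone sketch below illustrates that index mapping in plain C++; the function and parameter names are invented for the illustration, and the committed code expresses the same relation with tensor-descriptor transforms such as Pad, Embed, and Merge.

#include <cstddef>
#include <vector>

// Illustrative sketch only: maps the merged indices (e, b) of the implicit-GEMM
// input matrix back to an element of a padded NCHW input tensor.
//   e = (c, y, x)   with E = C * Y * X
//   b = (n, ho, wo) with B = N * Ho * Wo
// Reads that fall into the padded border return zero.
float implicit_gemm_input_element(const std::vector<float>& in, // NCHW storage
                                  int C, int Hi, int Wi,
                                  int Y, int X, int Ho, int Wo,
                                  int stride_h, int stride_w,
                                  int dilation_h, int dilation_w,
                                  int pad_h, int pad_w,
                                  int e, int b)
{
    // decompose the merged indices
    const int c  = e / (Y * X);
    const int y  = (e / X) % Y;
    const int x  = e % X;
    const int n  = b / (Ho * Wo);
    const int ho = (b / Wo) % Ho;
    const int wo = b % Wo;

    // multi-index transformation with padding: output position + filter tap -> input position
    const int hi = ho * stride_h + y * dilation_h - pad_h;
    const int wi = wo * stride_w + x * dilation_w - pad_w;

    if(hi < 0 || hi >= Hi || wi < 0 || wi >= Wi)
        return 0.0f; // padded region

    const std::size_t idx =
        ((static_cast<std::size_t>(n) * C + c) * Hi + hi) * Wi + wi;
    return in[idx];
}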
@@ -47,14 +47,17 @@ include_directories(BEFORE
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm
+    ${PROJECT_SOURCE_DIR}/external/include
     ${PROJECT_SOURCE_DIR}/driver/include
     ${PROJECT_BINARY_DIR}/composable_kernel/include/utility
 )
 
 if(DEVICE_BACKEND STREQUAL "AMD")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config_amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
 elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config_nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
 endif()
 
 add_subdirectory(driver)
#ifndef CK_CONVOLUTION_COMMON_HPP
#define CK_CONVOLUTION_COMMON_HPP
namespace ck {
enum ConvolutionDirection
{
Forward,
BackwardData,
BackwardWeight
};
} // namespace ck
#endif
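Editor's note (not part of the commit): the new header above only declares the ConvolutionDirection enum. As a hedged illustration of how such a tag might be used for compile-time dispatch, the sketch below mirrors the enum in its own namespace; the direction_name() helper is hypothetical and does not exist in the repository.

// Illustrative sketch only: compile-time dispatch on a convolution direction tag.
namespace ck_example {

enum ConvolutionDirection
{
    Forward,
    BackwardData,
    BackwardWeight
};

template <ConvolutionDirection Dir>
constexpr const char* direction_name()
{
    return Dir == Forward ? "forward"
                          : (Dir == BackwardData ? "backward-data" : "backward-weight");
}

static_assert(direction_name<Forward>()[0] == 'f', "sanity check");

} // namespace ck_example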
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "blockwise_2d_tensor_op.hpp"
 #include "blockwise_4d_tensor_op.hpp"
 #include "threadwise_tensor_slice_copy.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_4d_tensor_op.hpp"
 #include "blockwise_2d_tensor_op.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_2d_tensor_op.hpp"
 #include "blockwise_3d_tensor_op.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_generic_tensor_slice_copy.hpp"
@@ -125,8 +125,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
         // blockwise copy
         // input: format is [C, Hi, Wi, N]
-        auto blockwise_in_copy =
-            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
+            BlockSize,
             decltype(in_c_h_w_n_global_desc),
             decltype(in_c_h_w_n_block_desc),
             decltype(in_c_h_w_n_block_desc.GetLengths()),
@@ -138,13 +138,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
             3,
             3,
             InBlockCopyDataPerAccess_N,
-            InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
-                                        {0, 0, 0, 0});
+            InBlockCopyDataPerAccess_N>({0, 0, 0, 0}, {0, 0, 0, 0});
 
         // blockwise wei copy
         // format is [CPerBlock, X * KPerBlock]
         const auto blockwise_wei_copy =
-            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
             decltype(wei_c_k_global_desc),
             decltype(wei_c_k_block_desc),
             decltype(wei_c_k_block_desc.GetLengths()),
@@ -156,7 +155,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
             1,
             1,
             WeiBlockCopyDataPerAccess_K,
-            WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
+            WeiBlockCopyDataPerAccess_K>({0, 0},
+                                         {0, 0});
 
         // a series of blockwise batched GEMM
         // C_matrix += transpose(A_matrix) * B_matrix
@@ -318,14 +318,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                 n_block_data_begin + n_thread_data_begin);
 
 #if 1
-            ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+            ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
+                decltype(out_10d_thread_desc),
                 decltype(out_10d_global_desc),
                 decltype(out_10d_thread_desc.GetLengths()),
                 arithmetic_sequence_gen<0, 10, 1>::type,
                 9,
                 OutThreadCopyDataPerAccess_N,
-                OutThreadCopyDataPerAccess_N>(
-                make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
+                                              make_zero_array<index_t, 10>())
                 .Run(p_out_thread, p_out_thread_on_global);
 #elif 0
             ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
@@ -388,14 +389,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                 n_block_data_begin + n_thread_data_begin);
 
 #if 1
-            ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+            ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
+                decltype(out_10d_thread_desc),
                 decltype(out_10d_global_desc),
                 decltype(out_10d_thread_desc.GetLengths()),
                 arithmetic_sequence_gen<0, 10, 1>::type,
                 9,
                 OutThreadCopyDataPerAccess_N,
-                OutThreadCopyDataPerAccess_N>(
-                make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
+                                              make_zero_array<index_t, 10>())
                 .Run(p_out_thread, p_out_thread_on_global);
 #elif 0
             ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_generic_tensor_slice_copy.hpp"
@@ -127,9 +127,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
         // input: format is [C, Hi, Wi, N]
         auto blockwise_in_copy =
 #if 0
-            BlockwiseGenericTensorSliceCopy_v1
+            BlockwiseGenericTensorSliceCopy_v1_deprecated
 #else
-            BlockwiseGenericTensorSliceCopy_v2
+            BlockwiseGenericTensorSliceCopy_v2_deprecated
 #endif
             <BlockSize,
              decltype(in_c_h_w_n_global_desc),
@@ -149,9 +149,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
         // format is [CPerBlock, X * KPerBlock]
         const auto blockwise_wei_copy =
 #if 0
-            BlockwiseGenericTensorSliceCopy_v1
+            BlockwiseGenericTensorSliceCopy_v1_deprecated
 #else
-            BlockwiseGenericTensorSliceCopy_v2
+            BlockwiseGenericTensorSliceCopy_v2_deprecated
 #endif
             <BlockSize,
              decltype(wei_c_k_global_desc),
@@ -406,14 +406,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
                 n_block_data_begin + n_thread_data_begin);
 
 #if 1
-            ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+            ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
+                decltype(out_10d_thread_desc),
                 decltype(out_10d_global_desc),
                 decltype(out_10d_thread_desc.GetLengths()),
                 arithmetic_sequence_gen<0, 10, 1>::type,
                 9,
                 OutThreadCopyDataPerAccess_N,
-                OutThreadCopyDataPerAccess_N>(
-                make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
+                                              make_zero_array<index_t, 10>())
                 .Run(p_out_thread, p_out_thread_on_global);
 #elif 0
             ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
@@ -476,14 +477,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
                 n_block_data_begin + n_thread_data_begin);
 
 #if 1
-            ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+            ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
+                decltype(out_10d_thread_desc),
                 decltype(out_10d_global_desc),
                 decltype(out_10d_thread_desc.GetLengths()),
                 arithmetic_sequence_gen<0, 10, 1>::type,
                 9,
                 OutThreadCopyDataPerAccess_N,
-                OutThreadCopyDataPerAccess_N>(
-                make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
+                                              make_zero_array<index_t, 10>())
                 .Run(p_out_thread, p_out_thread_on_global);
 #elif 0
             ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_PADDED_HPP
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_2d_tensor_op.hpp"
 #include "blockwise_tensor_slice_copy.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_2d_tensor_op.hpp"
 #include "blockwise_tensor_slice_copy.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_4d_tensor_op.hpp"
 #include "blockwise_2d_tensor_op.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_4d_tensor_op.hpp"
 #include "blockwise_2d_tensor_op.hpp"
......
@@ -2,8 +2,8 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
-#include "ConstantMergedTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
@@ -128,7 +128,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
         // input blockwise copy
         // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
-        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
+        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
             BlockSize,
             Float,
             decltype(in_c_n1_b_n2_global_merged_desc),
@@ -155,8 +155,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
         // operator for blockwise copy of weight into LDS
         // slice a tensor, and copy it into another tensor
         // this copy operator already have blockwise offset built-in
-        auto blockwise_wei_copy =
-            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
+            BlockSize,
             Float,
             decltype(wei_c_k_global_desc),
             decltype(wei_c_k_block_desc),
@@ -167,8 +167,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
             Sequence<0, 1>, // src_access_order [C, K]
             Sequence<0, 1>, // dst_access_order [C, K]
             WeiBlockCopyDataPerAccess_K,
-            WeiBlockCopyDataPerAccess_K>(
-            {0, k_block_data_on_global}, {0, 0});
+            WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0});
 
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
......
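Editor's note (not part of the commit): several hunks above carry the comment "slice a merged tensor, reorder and copy to a normal tensor; this copy operator already has blockwise offset built-in". As a hedged, much-simplified illustration of what a blockwise slice copy with a built-in block offset does, consider the sketch below; all names are invented, and the real BlockwiseGenericTensorSliceCopy_v1_deprecated is far more general (it operates on tensor descriptors, cluster layouts, and access orders).

#include <cstddef>
#include <vector>

// Simplified sketch of a "blockwise" 2-D slice copy: BlockSize threads cooperate
// to copy a [rows_per_block x cols_per_block] tile whose origin (the blockwise
// offset) is fixed when the copy object is created.
struct BlockwiseTileCopySketch
{
    int src_stride;     // leading dimension of the source matrix
    int row_origin;     // blockwise offset, rows
    int col_origin;     // blockwise offset, columns
    int rows_per_block;
    int cols_per_block;
    int block_size;     // number of cooperating threads

    // Each "thread" copies a contiguous strip of columns of every row of the tile.
    void run(int thread_id, const std::vector<float>& src, std::vector<float>& dst_tile) const
    {
        const int cols_per_thread = cols_per_block / block_size; // assume divisible
        const int col_begin       = thread_id * cols_per_thread;

        for(int r = 0; r < rows_per_block; ++r)
        {
            for(int c = col_begin; c < col_begin + cols_per_thread; ++c)
            {
                const std::size_t src_idx =
                    static_cast<std::size_t>(row_origin + r) * src_stride + (col_origin + c);
                dst_tile[static_cast<std::size_t>(r) * cols_per_block + c] = src[src_idx];
            }
        }
    }
};

On a GPU, run() would be executed by every thread of the workgroup with its own thread_id, and dst_tile would live in LDS.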
@@ -2,8 +2,8 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
-#include "ConstantMergedTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
@@ -125,7 +125,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
         // input blockwise copy
         // slice a merged tensor, reorder and copy to a normal tensor
         // this copy operator already has blockwise offset built-in
-        const auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
+        const auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
             BlockSize,
             Float,
             decltype(in_c_n1_b_n2_global_merged_desc),
@@ -152,8 +152,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
         // operator for blockwise copy of weight into LDS
         // slice a tensor, and copy it into another tensor
         // this copy operator already have blockwise offset built-in
-        const auto blockwise_wei_copy =
-            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+        const auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
+            BlockSize,
             Float,
             decltype(wei_c_k_global_desc),
             decltype(wei_c_k_block_desc),
@@ -164,8 +164,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
             Sequence<0, 1>, // src_access_order [C, K]
             Sequence<0, 1>, // dst_access_order [C, K]
             WeiBlockCopyDataPerAccess_K,
-            WeiBlockCopyDataPerAccess_K>(
-            {0, k_block_data_on_global}, {0, 0});
+            WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0});
 
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
......
@@ -2,8 +2,8 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
-#include "ConstantMergedTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
@@ -166,7 +166,7 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
         // input blockwise copy
         // slice a merged tensor, reorder and copy to a normal tensor
         // this copy operator already has blockwise offset built-in
-        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
+        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
             BlockSize,
             Float,
             decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc),
@@ -196,7 +196,7 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
         // slice a tensor, and copy it into another tensor
         // this copy operator already have blockwise offset built-in
         auto blockwise_wei_copy =
-            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
                 Float,
                 decltype(wei_e_k_global_desc),
                 decltype(wei_e_k_block_desc),
......
@@ -2,8 +2,8 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
-#include "ConstantMergedTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
@@ -165,7 +165,7 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
         // input blockwise copy
         // slice a merged tensor, reorder and copy to a normal tensor
         // this copy operator already has blockwise offset built-in
-        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
+        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
             BlockSize,
             Float,
             decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc),
@@ -195,7 +195,7 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
         // slice a tensor, and copy it into another tensor
         // this copy operator already have blockwise offset built-in
         auto blockwise_wei_copy =
-            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
                 Float,
                 decltype(wei_e_k_global_desc),
                 decltype(wei_e_k_block_desc),
......
-#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP
-#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_DEPRECATRD_HPP
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_DEPRECATRD_HPP
 
 #include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_copy.hpp"
-#include "threadwise_generic_tensor_slice_copy.hpp"
+#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
 #include "blockwise_gemm.hpp"
+#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
 
 namespace ck {
 
 // B = merge(N, Ho, Wo)
 template <index_t GridSize,
           index_t BlockSize,
-          typename Float,
-          typename InGlobalDesc,
-          typename WeiGlobalDesc,
-          typename OutGlobalDesc,
-          typename ConvStrides,
-          typename ConvDilations,
-          typename LeftPads,
-          typename RightPads,
+          class Float,
+          class InGlobalDesc,
+          class WeiGlobalDesc,
+          class OutGlobalDesc,
+          class ConvStrides,
+          class ConvDilations,
           index_t BPerBlock,
           index_t KPerBlock,
           index_t EPerBlock,
@@ -34,21 +32,21 @@ template <index_t GridSize,
           index_t GemmKPerThreadLoop,
           index_t GemmDataPerReadA,
           index_t GemmDataPerReadB,
-          typename InBlockCopySubLengths_E_B,
-          typename InBlockCopyClusterLengths_E_B,
-          typename InBlockCopyThreadClusterArrangeOrder,
-          typename InBlockCopySrcAccessOrder,
-          typename InBlockCopyDstAccessOrder,
+          class InBlockCopySubLengths_E_B,
+          class InBlockCopyClusterLengths_E_B,
+          class InBlockCopyThreadClusterArrangeOrder,
+          class InBlockCopySrcAccessOrder,
+          class InBlockCopyDstAccessOrder,
           index_t InBlockCopyDataPerAccess_B,
-          typename WeiBlockCopySubLengths_E_K,
-          typename WeiBlockCopyClusterLengths_E_K,
-          typename WeiBlockCopyThreadClusterArrangeOrder,
-          typename WeiBlockCopySrcAccessOrder,
-          typename WeiBlockCopyDstAccessOrder,
+          class WeiBlockCopySubLengths_E_K,
+          class WeiBlockCopyClusterLengths_E_K,
+          class WeiBlockCopyThreadClusterArrangeOrder,
+          class WeiBlockCopySrcAccessOrder,
+          class WeiBlockCopyDstAccessOrder,
           index_t WeiBlockCopySrcDataPerRead_E,
           index_t WeiBlockCopyDstDataPerWrite_K,
           index_t OutThreadCopyDataPerAccess_B>
-struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer
+struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated
 {
     __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,
@@ -58,27 +56,23 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};
-        constexpr auto I5 = Number<5>{};
 
         constexpr auto True = integral_constant<bool, true>{};
 
-        constexpr auto in_n_c_hi_wi_global_desc =
-            make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
-        constexpr auto wei_k_c_y_x_global_desc =
-            make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
-        constexpr auto out_n_k_ho_wo_global_desc =
-            make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
+        constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
+        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
+        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
 
-        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLength(I0);
-        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLength(I1);
-        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
-        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
+        constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
+        constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
 
-        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLength(I1);
-        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
-        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
+        constexpr index_t K  = out_n_k_h_w_global_desc.GetLengths()[1];
+        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
+        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
 
-        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
-        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
+        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
+        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
 
         constexpr index_t ConvStrideH = ConvStrides{}[0];
         constexpr index_t ConvStrideW = ConvStrides{}[1];
@@ -103,45 +97,43 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
         constexpr index_t BBlockWork = B / BPerBlock;
 
         constexpr auto block_work_desc =
-            make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
+            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
 
-        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
+        const auto block_work_multi_id =
+            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
 
-        const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
-        const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
+        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
+        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
 
         // input tensor
-        // global mem
-        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
-            in_n_c_hi_wi_global_desc,
-            make_tuple(
-                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
-
-        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
-            in_n_c_hip_wip_global_desc,
-            make_tuple(PassThrough<N>{},
-                       PassThrough<C>{},
-                       Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
-                       Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
-
-        constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
-            in_n_c_y_ho_x_wo_global_desc,
-            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
-            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-        // LDS mem
+        // tensor descriptor in device memory [N, Ho, Wo]
+        constexpr auto in_n_ho_wo_global_desc =
+            in_n_c_h_w_global_desc.Extract(I0, I2, I3)
+                .StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
+                .StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});
+
+        // batch descritpor for device memory
+        constexpr auto in_c_y_x_global_desc =
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
+                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
+                .Extract(Sequence<1, 2, 3>{});
+
+        // merged tensor descriptor in device memory [E, B], src of blockwise copy
+        constexpr auto in_e_b_global_desc =
+            make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc),
+                                                Sequence<0, 1, 2>{},
+                                                Sequence<3, 4, 5>{});
+
+        // memory layout descriptor in LDS [E, B], dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto in_e_b_block_desc =
-            make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});
+            make_ConstantTensorDescriptor_packed(Sequence<EPerBlock, BPerBlock>{});
 
         // input blockwise copy
+        // slice a merged tensor, reorder and copy to a normal tensor
+        // this copy operator already has blockwise offset built-in
         auto blockwise_in_copy =
-            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
                 decltype(in_e_b_global_desc),
                 decltype(in_e_b_block_desc),
                 decltype(in_e_b_block_desc.GetLengths()),
@@ -157,13 +149,13 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
                 {0, b_block_data_on_global}, {0, 0});
 
         // weight tensor
-        // global mem
-        constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
-            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
+        // tensor descriptor in device memory, src of blockwise copy
+        constexpr auto wei_e_k_global_desc =
+            wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
 
-        // LDS
+        // tensor descriptor in LDS, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
+        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<EPerBlock, KPerBlock>{},
             Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
@@ -173,9 +165,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
         static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
                       "GemmDataPerReadA alignment requirement is not satisfied");
 
-        // weight blockwise copy
+        // operator for blockwise copy of weight into LDS
+        // slice a tensor, and copy it into another tensor
+        // this copy operator already have blockwise offset built-in
         auto blockwise_wei_copy =
-            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
                 decltype(wei_e_k_global_desc),
                 decltype(wei_e_k_block_desc),
                 decltype(wei_e_k_block_desc.GetLengths()),
@@ -253,12 +247,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
         // zero out threadwise output
         threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
 
+        const Float* p_wei_block_on_global = p_wei_global;
+
         // LDS double buffer: preload data into LDS
         {
-            blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
-                p_in_global, p_in_block_double);
-            blockwise_wei_copy.template Run<Float, Float, address_space_t::global>(
-                p_wei_global, p_wei_block_double);
+            blockwise_in_copy.template Run<Float, AddressSpace::global>(p_in_global,
+                                                                        p_in_block_double);
+            blockwise_wei_copy.template Run<Float, AddressSpace::global>(p_wei_global,
+                                                                         p_wei_block_double);
         }
 
         // LDS double buffer: main body
@@ -289,11 +285,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
                 __syncthreads();
 
                 // LDS doubel buffer: load next data from device mem
-                blockwise_in_copy
-                    .template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
+                blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
                     p_in_global, p_in_thread_buffer);
-                blockwise_wei_copy
-                    .template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
+                blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
                     p_wei_global, p_wei_thread_buffer);
 
                 // LDS double buffer: GEMM on current data
@@ -317,9 +311,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
             __syncthreads();
 
             // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
+            blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
                 p_in_global, p_in_thread_buffer);
-            blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
+            blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
                 p_wei_global, p_wei_thread_buffer);
 
             // LDS double buffer: GEMM on current data
@@ -342,6 +336,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
         // copy output: register to global memory
         {
+            constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
+            constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
+
+            // define tensor descriptor for threadwise copy
+            // output global descriptor, for calculating origin of thread tensor
+            // in global memory
+            constexpr auto out_k_b_global_desc = make_ConstantMergedTensorDescriptor(
+                out_n_k_h_w_global_desc, Sequence<1>{}, Sequence<0, 2, 3>{});
+
             // calculate origin of thread output tensor on global memory
             // blockwise GEMM c matrix starting index
             const auto c_thread_mtx_on_block =
@@ -353,51 +356,46 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
             const index_t b_thread_data_on_global =
                 b_block_data_on_global + c_thread_mtx_on_block.col;
 
-            // src descriptor
-            constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
-                Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});
-
-            // dst descriptor
-            constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
-            constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
-
-            constexpr index_t K0 = K / K1;
-            constexpr index_t B0 = B / B1;
-
-            constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
-                out_n_k_ho_wo_global_desc,
-                make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
-                make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-            constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
-                out_k_b_global_desc,
-                make_tuple(UnMerge<Sequence<K0, K1>>{}, UnMerge<Sequence<B0, B1>>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-
-            // output threadwise copy
-            ThreadwiseGenericTensorSliceCopy_v4r2<
-                decltype(out_k0_k1_b0_b1_thread_desc),
-                decltype(out_k0_k1_b0_b1_global_desc),
-                decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
-                arithmetic_sequence_gen<0, 4, 1>::type,
-                3,
-                OutThreadCopyDataPerAccess_B,
-                OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
-                                              {k_thread_data_on_global / K1,
-                                               k_thread_data_on_global % K1,
-                                               b_thread_data_on_global / B1,
-                                               b_thread_data_on_global % B1})
-#if 1
-                .template Run<Float, Float, address_space_t::generic, address_space_t::global>
-#else
-                .template Run_optimized_dst_address_calculation<Float,
-                                                                Float,
-                                                                address_space_t::generic,
-                                                                address_space_t::global>
-#endif
-                (p_out_thread, p_out_global);
+            // This is a hack, because slicing a merged dimension is not supported yet.
+            // This should be replaced with logic above, once slicing a merged dimension support
+            // become available
+
+            // dst descriptor
+            constexpr auto out_k0_k1_b_global_desc =
+                make_ConstantMergedTensorDescriptor(out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}),
+                                                    Sequence<1>{},
+                                                    Sequence<2>{},
+                                                    Sequence<0, 3, 4>{});
+
+            // src descriptor
+            constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed(
+                Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{});
+
+            using OutThreadCopySliceLengths =
+                Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
+
+            auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
+                decltype(out_k0_k1_b_thread_desc),
+                decltype(out_k0_k1_b_global_desc),
+                OutThreadCopySliceLengths,
+                arithmetic_sequence_gen<0, 3, 1>::type,
+                arithmetic_sequence_gen<0, 3, 1>::type,
+                2,
+                2,
+                OutThreadCopyDataPerAccess_B,
+                OutThreadCopyDataPerAccess_B>({0, 0, 0},
+                                              {k_thread_data_on_global / K1,
+                                               k_thread_data_on_global % K1,
+                                               b_thread_data_on_global});
+
+            for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
+            {
+                // tweaking
+                threadwise_out_copy
+                    .template Run<Float, AddressSpace::generic, AddressSpace::global>(p_out_thread,
+                                                                                      p_out_global);
+
+                threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
+                threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True);
+            }
         }
     }
 };
......
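Editor's note (not part of the commit): the hunks above repeatedly reference the LDS double-buffer pipeline ("preload data into LDS", "load next data from device mem", "GEMM on current data"). The stand-alone sketch below shows that ping-pong structure in plain C++ under the assumption of a sequential model; all names are invented, and in the real kernels the buffers live in LDS and the next load overlaps with the current GEMM.

#include <cstddef>
#include <vector>

// Conceptual sketch of double buffering: while tile t is being consumed
// (the "GEMM on current data" step), tile t+1 is loaded into the other buffer.
void double_buffered_pipeline(const std::vector<float>& global_data,
                              std::size_t tile_size,
                              std::size_t num_tiles,
                              float& accumulator)
{
    std::vector<float> buffer[2] = {std::vector<float>(tile_size),
                                    std::vector<float>(tile_size)};

    auto load_tile = [&](std::size_t tile, std::vector<float>& dst) {
        for(std::size_t i = 0; i < tile_size; ++i)
            dst[i] = global_data[tile * tile_size + i];
    };

    auto compute_on_tile = [&](const std::vector<float>& src) {
        for(float v : src)
            accumulator += v; // stand-in for the blockwise GEMM
    };

    // preload the first tile ("LDS double buffer: preload data into LDS")
    load_tile(0, buffer[0]);

    for(std::size_t t = 0; t < num_tiles; ++t)
    {
        const std::size_t cur = t % 2;
        const std::size_t nxt = (t + 1) % 2;

        // load next tile into the other buffer ("load next data from device mem")
        if(t + 1 < num_tiles)
            load_tile(t + 1, buffer[nxt]);

        // consume the current buffer ("GEMM on current data")
        compute_on_tile(buffer[cur]);
    }
}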