"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "d66421fe34f2b69de7fe53876a7eb5dea4f3fd9f"
Commit 40836ab9 authored by Chao Liu
Browse files

add back some code

parent 8bdaba51
...@@ -7,6 +7,10 @@ ...@@ -7,6 +7,10 @@
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_batched_gemm.hpp" #include "blockwise_batched_gemm.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
namespace ck { namespace ck {
...@@ -129,10 +133,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -129,10 +133,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed( constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{}); Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
#if 1
// blockwise copy // blockwise copy
// input: format is [C, Hi, Wi, N] // input: format is [C, Hi, Wi, N]
const auto blockwise_in_copy =
Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerAccess_N>{};
#else
auto blockwise_in_copy = auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize, BlockwiseGenericTensorSliceCopy_v1<BlockSize,
decltype(in_c_h_w_n_global_desc), decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc), decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()), decltype(in_c_h_w_n_block_desc.GetLengths()),
...@@ -146,11 +160,21 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -146,11 +160,21 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
InBlockCopyDataPerAccess_N, InBlockCopyDataPerAccess_N,
InBlockCopyDataPerAccess_N>({0, 0, 0, 0}, InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
{0, 0, 0, 0}); {0, 0, 0, 0});
#endif
#if 1
// blockwise wei copy // blockwise wei copy
// format is [CPerBlock, X * KPerBlock] // format is [CPerBlock, X * KPerBlock]
const auto blockwise_wei_copy = const auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize, Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
#else
const auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
decltype(wei_c_k_global_desc), decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc), decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()), decltype(wei_c_k_block_desc.GetLengths()),
...@@ -163,6 +187,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -163,6 +187,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
1, 1,
WeiBlockCopyDataPerAccess_K, WeiBlockCopyDataPerAccess_K,
WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0}); WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
#endif
// a series of blockwise batched GEMM // a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix // C_matrix += transpose(A_matrix) * B_matrix
...@@ -402,7 +427,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -402,7 +427,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin); n_block_data_begin + n_thread_data_begin);
ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_10d_thread_desc), #if 1
threadwise_tensor_slice_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_thread_on_global,
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerAccess_N>{});
#else
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
decltype(out_10d_global_desc), decltype(out_10d_global_desc),
decltype(out_10d_thread_desc.GetLengths()), decltype(out_10d_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 10, 1>::type, arithmetic_sequence_gen<0, 10, 1>::type,
...@@ -413,6 +446,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -413,6 +446,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
OutThreadCopyDataPerAccess_N>( OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>()) make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global); .Run(p_out_thread, p_out_thread_on_global);
#endif
}).Else([&](auto fwd) { }).Else([&](auto fwd) {
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0, GemmNPerThreadSubC % NPerThread == 0,
...@@ -460,7 +494,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -460,7 +494,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin); n_block_data_begin + n_thread_data_begin);
ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_10d_thread_desc), #if 1
threadwise_tensor_slice_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_thread_on_global,
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerAccess_N>{});
#else
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
decltype(out_10d_global_desc), decltype(out_10d_global_desc),
decltype(out_10d_thread_desc.GetLengths()), decltype(out_10d_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 10, 1>::type, arithmetic_sequence_gen<0, 10, 1>::type,
...@@ -471,6 +513,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -471,6 +513,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
OutThreadCopyDataPerAccess_N>( OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>()) make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global); .Run(p_out_thread, p_out_thread_on_global);
#endif
}); });
} }
}; };
......
...@@ -143,7 +143,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -143,7 +143,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t WeiBlockCopyDataPerAccess_K = 4; constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 2; constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 1 #elif 0
// for 3x3, 34x34, v1r3, Pascal // for 3x3, 34x34, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal // for 3x3, 28x28, v1r3, Pascal
// for 3x3, 14x14, v1r3, Pascal // for 3x3, 14x14, v1r3, Pascal
...@@ -266,9 +266,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -266,9 +266,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>; using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4; constexpr index_t InBlockCopyDataPerAccess_N = 4;
using WeiBlockCopySubLengths_CK = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4; constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 4; constexpr index_t OutThreadCopyDataPerAccess_N = 4;
......
...@@ -71,7 +71,7 @@ int main(int argc, char* argv[]) ...@@ -71,7 +71,7 @@ int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
#if 1 #if 0
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 1536; constexpr index_t C = 1536;
constexpr index_t HI = 8; constexpr index_t HI = 8;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment