Commit b2b622e8 authored by Chao Liu
Browse files

refactor

parent a65ef903
...@@ -87,9 +87,6 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, ...@@ -87,9 +87,6 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned HoPerThread = 1; constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1; constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopy_ThreadPerDimC = 4; constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
constexpr unsigned InBlockCopy_ThreadPerDimH = 4; constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2; constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
...@@ -278,8 +275,6 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, ...@@ -278,8 +275,6 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
KPerThread, KPerThread,
HoPerThread, HoPerThread,
WoPerThread, WoPerThread,
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1,
Sequence<InBlockCopy_ThreadPerDimC, Sequence<InBlockCopy_ThreadPerDimC,
InBlockCopy_ThreadPerDimH, InBlockCopy_ThreadPerDimH,
InBlockCopy_ThreadPerDimW, InBlockCopy_ThreadPerDimW,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <numeric> #include <numeric>
#include <initializer_list> #include <initializer_list>
#include <cstdlib> #include <cstdlib>
#include <stdlib.h>
#include "config.h" #include "config.h"
#include "tensor.hpp" #include "tensor.hpp"
#include "ConstantTensorDescriptor.hip.hpp" #include "ConstantTensorDescriptor.hip.hpp"
...@@ -378,7 +379,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result) ...@@ -378,7 +379,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl; std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
} }
int main() int main(int argc, char* argv[])
{ {
#if 0 #if 0
constexpr unsigned N = 1; constexpr unsigned N = 1;
...@@ -571,7 +572,14 @@ int main() ...@@ -571,7 +572,14 @@ int main()
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
bool do_verification = true; if(argc != 3)
{
printf("arg1: do_verification, arg2: nrepeat\n");
exit(1);
}
bool do_verification = atoi(argv[1]);
unsigned nrepeat = atoi(argv[2]);
if(do_verification) if(do_verification)
{ {
...@@ -587,8 +595,6 @@ int main() ...@@ -587,8 +595,6 @@ int main()
#endif #endif
} }
unsigned nrepeat = 200;
#if 1 #if 1
#if 0 #if 0
device_direct_convolution_1 device_direct_convolution_1
......
...@@ -23,8 +23,6 @@ template <unsigned GridSize, ...@@ -23,8 +23,6 @@ template <unsigned GridSize,
unsigned KPerThread, unsigned KPerThread,
unsigned HoPerThread, unsigned HoPerThread,
unsigned WoPerThread, unsigned WoPerThread,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1,
class InBlockCopyThreadPerDims, class InBlockCopyThreadPerDims,
unsigned InBlockCopyDataPerRead, unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead, unsigned WeiBlockCopyDataPerRead,
...@@ -109,69 +107,31 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -109,69 +107,31 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
constexpr auto out_khwn_thread_desc = constexpr auto out_khwn_thread_desc =
make_ConstantTensorDescriptor(Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{}); make_ConstantTensorDescriptor(Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc");
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc");
print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
}
#endif
// blockwise copy // blockwise copy
// input: format is [C, Hi, Wi, N] // input: format is [C, Hi, Wi, N]
#if 0 const auto blockwise_in_copy = Blockwise4dTensorCopy3<BlockSize,
constexpr auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths())>{};
#elif 1
const auto blockwise_in_copy = Blockwise4dTensorCopy3<BlockSize,
Float, Float,
decltype(in_chwn_global_desc), decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc), decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths()), decltype(in_chwn_block_desc.GetLengths()),
InBlockCopyThreadPerDims, InBlockCopyThreadPerDims,
InBlockCopyDataPerRead>{}; InBlockCopyDataPerRead>{};
#endif
// blockwise wei copy // blockwise wei copy
// format is [CPerBlock*S*R,KPerBlock] // format is [CPerBlock*S*R,KPerBlock]
#if 0
const auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize, const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
Float, Float,
decltype(wei_ek_global_desc), decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc), decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()), decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{}; WeiBlockCopyDataPerRead>{};
#endif
// a series of blockwise batched GEMM // a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix // C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register // A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K] // A_matrix[C,K] is a sub-matrix of wei_block[C,S,R,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N] // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_csrk_block_desc.GetStride(I0)>{}); Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_csrk_block_desc.GetStride(I0)>{});
...@@ -185,23 +145,6 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -185,23 +145,6 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
Number<WoPerThread * NPerThread>{}, Number<WoPerThread * NPerThread>{},
Number<out_khwn_thread_desc.GetStride(I1)>{}); Number<out_khwn_thread_desc.GetStride(I1)>{});
#if 0
const auto blockwise_batch_gemm =
Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxwn_block_mtx_desc),
decltype(c_kxwn_thread_mtx_desc),
true,
false,
false,
0,
in_chwn_block_desc.GetStride(I1),
out_khwn_thread_desc.GetStride(I1),
HoPerBlock,
HoPerThread,
GemmKPerThreadLoop,
true>{};
#else
const auto blockwise_batch_gemm = BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< const auto blockwise_batch_gemm = BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize, BlockSize,
decltype(a_cxk_block_mtx_desc), decltype(a_cxk_block_mtx_desc),
...@@ -219,7 +162,6 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -219,7 +162,6 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop, GemmKPerThreadLoop,
HoPerThread>{}; HoPerThread>{};
#endif
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
...@@ -277,26 +219,6 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -277,26 +219,6 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
const auto c_thread_mtx_begin = const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
// for v1 batch-gemm
const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const unsigned n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
threadwise_4d_tensor_copy_v2(
out_khwn_thread_desc,
p_out_thread,
out_khwn_global_desc,
p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_khwn_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite>{});
#elif 0
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k) for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
{ {
for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho) for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
...@@ -334,7 +256,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -334,7 +256,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
const unsigned k_thread_data_begin = c_thread_mtx_begin.row; const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch; const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const unsigned n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; const unsigned n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
// this is for v2 GEMM // this is for v2 GEMM
// output is a 8d tensor // output is a 8d tensor
...@@ -375,6 +297,8 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -375,6 +297,8 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
} }
else if(NPerThread == NPerBlock) else if(NPerThread == NPerBlock)
{ {
// not implemented yet
assert(false);
} }
else else
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment