Commit 2b52fbd2 authored by Chao Liu's avatar Chao Liu
Browse files

bug fix

parent ff7a6219
...@@ -94,7 +94,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -94,7 +94,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
constexpr auto out_hkwn_thread_desc = constexpr auto out_hkwn_thread_desc =
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{}); make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
#if 1 #if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{ {
print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc"); print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc");
...@@ -156,7 +156,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -156,7 +156,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1); for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
c_block_data_begin += CPerBlock, __syncthreads()) c_block_data_begin += CPerBlock, __syncthreads())
{ {
#if 1
// input: global mem to LDS, // input: global mem to LDS,
// convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N] // convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N]
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>( blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
...@@ -169,9 +168,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -169,9 +168,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
p_in_block, p_in_block,
in_nchw_block_desc.GetLengths(), in_nchw_block_desc.GetLengths(),
reorder_chwn_from_nchw); reorder_chwn_from_nchw);
#endif
#if 1
// weight: global mem to LDS, // weight: global mem to LDS,
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>( blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
wei_kcsr_global_desc, wei_kcsr_global_desc,
...@@ -181,11 +178,9 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -181,11 +178,9 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
p_wei_block, p_wei_block,
wei_kcsr_block_desc.GetLengths(), wei_kcsr_block_desc.GetLengths(),
reorder_srck_from_kcsr); reorder_srck_from_kcsr);
#endif
__syncthreads(); __syncthreads();
#if 1
// a series of batched GEMM // a series of batched GEMM
for(unsigned s = 0; s < S; ++s) for(unsigned s = 0; s < S; ++s)
{ {
...@@ -194,12 +189,11 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -194,12 +189,11 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
p_in_block + in_chwn_block_desc.Get1dIndex(0, 0, r, 0), p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
p_out_thread, p_out_thread,
f_accum); f_accum);
} }
} }
#endif
} }
const auto matrix_c_index = const auto matrix_c_index =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment