Commit 52423948 authored by Jehandad Khan

Merge branch 'master' into jd_redux

parents b97af4ec 98a2cfcc
......@@ -241,16 +241,15 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
#else
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block);
#endif
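// The "register clipboard" staging area is renamed to "register buffer" here; the
// two-phase scheme itself is unchanged: RunLoadRegisterBuffer reads this block's
// slice of global memory into per-thread registers, and RunStoreRegisterBuffer
// then writes those registers into LDS, so the long-latency global loads are
// issued before any LDS traffic.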
__syncthreads();
......
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_batched_gemm.hpp"
namespace ck {
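// Note on the template-parameter changes below: the weight blockwise copy is now
// configured like the input copy, with explicit per-thread slice lengths
// (WeiBlockCopySubLengths_CK) and a thread-cluster shape
// (WeiBlockCopyClusterLengths_CK), and the per-direction DataPerRead/DataPerWrite
// knobs are folded into single DataPerAccess parameters, which the instantiations
// below pass for both the source and the destination side of each copy.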
......@@ -37,10 +34,13 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_CHWN,
class InBlockCopyClusterLengths_CHWN,
index_t InBlockCopyDataPerRead_N,
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
index_t InBlockCopyDataPerAccess_N,
class WeiBlockCopySubLengths_CK,
class WeiBlockCopyClusterLengths_CK,
index_t WeiBlockCopyDataPerAccess_K,
index_t OutThreadCopyDataPerAccess_N>
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,
......@@ -79,21 +79,21 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock;
const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;
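// The block-work index is now decoded in [K, Ho, Wo, N] order instead of
// [N, K, Ho, Wo], making N the fastest-varying coordinate across consecutive
// workgroup ids; since N is also the innermost dimension of the CHWN/KHWN
// layouts, adjacent blocks should end up touching adjacent global memory.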
......@@ -103,8 +103,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
// LDS tensor view
// be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_N,
WeiBlockCopyDataPerAccess_K,
GemmDataPerReadA,
GemmDataPerReadB);
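// The LDS base alignment is the least common multiple of every vector width that
// will touch the buffers (the blockwise copies' per-access widths and the GEMM's
// per-read widths), so each vectorized LDS access stays naturally aligned.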
......@@ -123,24 +123,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy
// input: format is [C, Hi, Wi, N]
#if 0
const auto blockwise_in_copy =
Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N>{};
#elif 0
using InBlockCopySubLengths_CHWN =
decltype(in_c_h_w_n_block_desc.GetLengths() / InBlockCopyClusterLengths_CHWN{});
// blockwise copy
// input: format is [C, Hi, Wi, N]
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
......@@ -149,33 +135,28 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
1,
1>({0, 0, 0, 0}, {0, 0, 0, 0});
#elif 1
using InBlockCopySubLengths_CHWN =
decltype(in_c_h_w_n_block_desc.GetLengths() / InBlockCopyClusterLengths_CHWN{});
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
BlockSize,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
NormalTensorCoordinate<decltype(in_c_h_w_n_global_desc)>,
NormalTensorCoordinate<decltype(in_c_h_w_n_block_desc)>,
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopySubLengths_CHWN,
InBlockCopyClusterLengths_CHWN,
Sequence<0, 1, 2, 3>>({0, 0, 0, 0}, {0, 0, 0, 0});
#endif
3,
3,
InBlockCopyDataPerAccess_N,
InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
{0, 0, 0, 0});
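// In the argument list above, the two 3s name dimension 3 of [C, Hi, Wi, N]
// (the N axis) as the vector-access dimension for the source and destination
// sides, each moving InBlockCopyDataPerAccess_N elements per access; the two
// {0, 0, 0, 0} arrays are the initial source and destination slice origins.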
// blockwise wei copy
// format is [CPerBlock, X * KPerBlock]
const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_CK,
WeiBlockCopyClusterLengths_CK,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
WeiBlockCopyDataPerAccess_K,
WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
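// Same copy template for the [C, K] weight slice: here the vector-access
// dimension is 1 (the K axis), with WeiBlockCopyDataPerAccess_K elements per
// access on both sides.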
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
......@@ -278,7 +259,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
}
}
// output: register to global mem,
// output: register to global mem
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
......@@ -329,17 +310,36 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
}
#endif
threadwise_tensor_slice_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
Float* p_out_thread_on_global = p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin);
#if 1
ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
decltype(out_10d_global_desc),
decltype(out_10d_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 10, 1>::type,
9,
OutThreadCopyDataPerAccess_N,
OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global);
#elif 0
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
decltype(out_10d_global_desc),
decltype(out_10d_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 10, 1>::type,
arithmetic_sequence_gen<0, 10, 1>::type,
9,
9,
OutThreadCopyDataPerAccess_N,
OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global);
#endif
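// Two interchangeable threadwise copy variants sit behind the preprocessor
// switch above: v1r2 takes a single access order and one vector dimension (9,
// the N axis of the 10-d view), while v1r1 takes separate source/destination
// access orders and vector dimensions; only the v1r2 branch is compiled.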
}).Else([&](auto fwd) {
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
......@@ -380,17 +380,36 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
}
#endif
threadwise_tensor_slice_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
Float* p_out_thread_on_global = p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin);
#if 1
ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
decltype(out_10d_global_desc),
decltype(out_10d_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 10, 1>::type,
9,
OutThreadCopyDataPerAccess_N,
OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global);
#elif 0
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
decltype(out_10d_global_desc),
decltype(out_10d_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 10, 1>::type,
arithmetic_sequence_gen<0, 10, 1>::type,
9,
9,
OutThreadCopyDataPerAccess_N,
OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global);
#endif
});
}
};
......
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_batched_gemm.hpp"
namespace ck {
......@@ -36,10 +34,13 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_CHWN,
class InBlockCopyClusterLengths_CHWN,
index_t InBlockCopyDataPerRead_N,
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
index_t InBlockCopyDataPerAccess_N,
class WeiBlockCopySubLengths_CK,
class WeiBlockCopyClusterLengths_CK,
index_t WeiBlockCopyDataPerAccess_K,
index_t OutThreadCopyDataPerAccess_N>
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
......@@ -73,14 +74,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
// assert for LDS double buffer
static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");
// divide block work: [K, Ho, Wo, N]
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % (2 * CPerBlock) == 0 &&
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
......@@ -108,8 +103,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
// LDS tensor view
// be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_N,
WeiBlockCopyDataPerAccess_K,
GemmDataPerReadA,
GemmDataPerReadB);
......@@ -130,24 +125,47 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
// blockwise copy
// input: format is [C, Hi, Wi, N]
const auto blockwise_in_copy =
Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N>{};
auto blockwise_in_copy =
#if 0
BlockwiseGenericTensorSliceCopy_v1
#else
BlockwiseGenericTensorSliceCopy_v2
#endif
<BlockSize,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopySubLengths_CHWN,
InBlockCopyClusterLengths_CHWN,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
3,
3,
InBlockCopyDataPerAccess_N,
InBlockCopyDataPerAccess_N>({0, 0, 0, 0}, {0, 0, 0, 0});
// blockwise wei copy
// format is [CPerBlock, X * KPerBlock]
const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});
#if 0
BlockwiseGenericTensorSliceCopy_v1
#else
BlockwiseGenericTensorSliceCopy_v2
#endif
<BlockSize,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_CK,
WeiBlockCopyClusterLengths_CK,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
WeiBlockCopyDataPerAccess_K,
WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
......@@ -233,18 +251,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
// LDS double buffer: preload data into LDS
{
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double);
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double);
}
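// The main body below is the usual LDS ping-pong pipeline; per iteration,
// roughly:
//
//   __syncthreads();                       // stores into the "next" half done
//   load slice i+1: global -> registers    // long-latency loads issued first
//   GEMM on the LDS half holding slice i   // overlaps with the loads in flight
//   store registers -> the other LDS half
//   swap the "current" and "next" halves
//
// The tail then handles the last even/odd pair of slices outside the loop.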
// LDS double buffer: main body
......@@ -266,9 +284,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float
p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
p_in_global_block_offset +=
CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
......@@ -278,25 +295,25 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
......@@ -305,19 +322,19 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(
p_wei_register_clipboard, p_wei_block_double + wei_block_space);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
......@@ -330,7 +347,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
}
}
// output: register to global mem,
// output: register to global mem
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
......@@ -381,17 +398,36 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
}
#endif
threadwise_tensor_slice_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
Float* p_out_thread_on_global = p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin);
#if 1
ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
decltype(out_10d_global_desc),
decltype(out_10d_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 10, 1>::type,
9,
OutThreadCopyDataPerAccess_N,
OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global);
#elif 0
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
decltype(out_10d_global_desc),
decltype(out_10d_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 10, 1>::type,
arithmetic_sequence_gen<0, 10, 1>::type,
9,
9,
OutThreadCopyDataPerAccess_N,
OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global);
#endif
}).Else([&](auto fwd) {
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
......@@ -432,17 +468,36 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
}
#endif
threadwise_tensor_slice_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
Float* p_out_thread_on_global = p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin);
#if 1
ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
decltype(out_10d_global_desc),
decltype(out_10d_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 10, 1>::type,
9,
OutThreadCopyDataPerAccess_N,
OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global);
#elif 0
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
decltype(out_10d_global_desc),
decltype(out_10d_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 10, 1>::type,
arithmetic_sequence_gen<0, 10, 1>::type,
9,
9,
OutThreadCopyDataPerAccess_N,
OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global);
#endif
});
}
};
......
......@@ -254,19 +254,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
// LDS double buffer: preload data into LDS
{
Float p_in_register_clipboard[blockwise_in_copy_reorder
.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double);
Float p_in_register_buffer[blockwise_in_copy_reorder.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy_reorder.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
blockwise_in_copy_reorder.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double);
}
// LDS double buffer: main body
......@@ -288,10 +287,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy_reorder
.GetRegisterClipboardSize()];
Float
p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy_reorder.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
p_in_global_block_offset +=
CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
......@@ -301,27 +299,26 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy_reorder.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
blockwise_in_copy_reorder.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_clipboard[blockwise_in_copy_reorder
.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy_reorder.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
......@@ -330,19 +327,19 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy_reorder.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy_reorder.RunStoreRegisterClipboard(
p_in_register_clipboard, p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(
p_wei_register_clipboard, p_wei_block_double + wei_block_space);
blockwise_in_copy_reorder.RunStoreRegisterBuffer(
p_in_register_buffer, p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
......
......@@ -214,16 +214,15 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
__syncthreads())
{
// load data
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block);
__syncthreads();
......
......@@ -209,17 +209,15 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
// preload data into LDS
{
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_double);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_double);
}
// register
......@@ -247,18 +245,18 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
// load next data
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
__syncthreads();
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
// compute on current data
// a series of GEMM
......@@ -280,10 +278,8 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
}
}
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
......@@ -295,14 +291,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
__syncthreads();
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
for(index_t y = 0; y < Y; ++y)
{
......@@ -322,10 +317,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
}
}
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double + wei_block_space);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd
__syncthreads();
......
......@@ -267,9 +267,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float
p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
......@@ -277,26 +276,26 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_block_on_global,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_block_on_global,
p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
......@@ -305,19 +304,19 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_block_on_global,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_block_on_global,
p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(
p_wei_register_clipboard, p_wei_block_double + wei_block_space);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
......
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
......@@ -14,17 +14,16 @@ namespace ck {
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
typename Float,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t N1,
index_t N2,
index_t GemmNRepeat,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
......@@ -34,18 +33,18 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_N1_B_N2,
class InBlockCopyClusterLengths_E_N1_B_N2,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
typename InBlockCopySubLengths_E_N1_B_N2,
typename InBlockCopyClusterLengths_E_N1_B_N2,
typename InBlockCopyThreadClusterArrangeOrder,
typename InBlockCopySrcAccessOrder,
typename InBlockCopyDstAccessOrder,
index_t InBlockCopySrcDataPerRead_B,
index_t InBlockCopyDstDataPerWrite_N2,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
typename WeiBlockCopySubLengths_E_K,
typename WeiBlockCopyClusterLengths_E_K,
typename WeiBlockCopyThreadClusterArrangeOrder,
typename WeiBlockCopySrcAccessOrder,
typename WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
......@@ -56,7 +55,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
{
// this is a mess
// TODO: find a more elegant way of specifying (or calculating) performance parameters
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC;
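// N1 and N2 are no longer independent template parameters; they are derived
// from the GEMM configuration (N1 = GemmNRepeat, N2 = GemmNPerThreadSubC),
// which makes the old N2 == GemmNPerThreadSubC consistency assert redundant.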
static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
......@@ -99,10 +100,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr index_t E = C * Y * X;
// sanity-check for vectorized memory load
static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
"wrong! global vector load of input tensor is wrong");
static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
......@@ -155,13 +154,11 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not satisfied");
#if 0 // debug
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
Float,
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
......@@ -170,21 +167,11 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
2,
3,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>(
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
#else
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
BlockSize,
decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc),
MergedTensorCoordinate<decltype(in_e_n1_b_n2_global_merged_desc)>,
NormalTensorCoordinate<decltype(in_e_n1_b_n2_block_desc)>,
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
#endif
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
......@@ -197,13 +184,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
#if 0 // debug
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
Float,
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
......@@ -212,21 +203,11 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
#else
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder>({0, k_block_data_on_global}, {0, 0});
#endif
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -234,12 +215,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc =
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_n1bn2_block_mtx_desc =
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
in_e_n1_b_n2_block_desc.Unfold(I1, I3));
make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
......@@ -252,7 +231,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
// c_thread_mtx definition: this is a mess
// TODO: a more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
......@@ -302,17 +281,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
__syncthreads();
#if 0
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#else
blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0, 0, 0}, true);
blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true);
#endif
blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0, 0, 0), True);
blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0), True);
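// MoveSrcSliceWindow is the successor of MoveSrcSlicingWindow /
// MoveSlicingWindowOnSourceTensor: each pass it advances the source window by
// EPerBlock along the reduction dimension E, with the trailing True presumably
// selecting the forward direction.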
}
// copy output: register to global memory
{
#if 0
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
......@@ -358,26 +333,65 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
#if 0 // debug
threadwise_generic_tensor_slice_copy_v1(
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
p_out_thread,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
p_out_thread_on_global,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
arithmetic_sequence_gen<0, 8, 1>::type{},
Number<1>{});
#else
ThreadwiseGenericTensorSliceCopy_v2<
ThreadwiseGenericTensorSliceCopy_v2r1<
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
NormalTensorCoordinate<decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc)>,
MergedTensorCoordinate<decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc)>,
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths())>(
{0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 8, 1>::type,
arithmetic_sequence_gen<0, 8, 1>::type,
7,
7,
1,
1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
.Run(p_out_thread, p_out_thread_on_global);
#else
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register, src of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
// output memory layout descriptor in device memory
constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
// output merged global tensor descriptor, dst of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
Sequence<3>{},
Sequence<4>{},
Sequence<1>{},
Sequence<0, 5, 6>{},
Sequence<2>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
ThreadwiseGenericTensorSliceCopy_v2r1<
decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
decltype(out_k0_k1_n1_b_n2_global_merged_desc),
decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
arithmetic_sequence_gen<0, 5, 1>::type,
arithmetic_sequence_gen<0, 5, 1>::type,
3,
3,
1,
1>({0, 0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0})
.template Run_amd_experiment<Float, 0, 2>(p_out_thread, p_out_global);
#endif
}
}
......
......@@ -181,12 +181,6 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0},
{0, 0, 0, 0, 0, 0, 0, 0});
#if 0
{
printf("id (%d %d), in offset: %d %d\n", get_block_1d_id(), get_thread_local_1d_id(), blockwise_in_copy.mThreadSrcOffset, blockwise_in_copy.mThreadDstOffset);
}
#endif
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
......@@ -222,8 +216,7 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc =
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
......@@ -233,8 +226,7 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
"GemmDataPerReadB alignment requirement is not satisfied");
constexpr auto b_e_n0ho0wo0bn2ho2wo2_block_mtx_desc =
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
make_ConstantMatrixDescriptor(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
......@@ -313,8 +305,8 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
......@@ -322,25 +314,23 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
......@@ -349,18 +339,17 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double + wei_block_space);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
......
......@@ -228,8 +228,7 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc =
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
......@@ -239,8 +238,7 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
"GemmDataPerReadB alignment requirement is not satisfied");
constexpr auto b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc =
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
make_ConstantMatrixDescriptor(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
......@@ -319,8 +317,8 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
......@@ -328,9 +326,9 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
#if 0
if(get_block_1d_id() == 0)
......@@ -338,10 +336,10 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
printf("tid (%d %d), %f %f %f %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
p_wei_register_clipboard[0],
p_wei_register_clipboard[1],
p_wei_register_clipboard[2],
p_wei_register_clipboard[3]);
p_wei_register_buffer[0],
p_wei_register_buffer[1],
p_wei_register_buffer[2],
p_wei_register_buffer[3]);
}
#endif
......@@ -349,17 +347,15 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
......@@ -368,18 +364,17 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double + wei_block_space);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
......
......@@ -44,7 +44,8 @@ template <index_t GridSize,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
......@@ -82,7 +83,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
constexpr index_t E = C * Y * X;
constexpr index_t B = N * Ho * Wo;
static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
......@@ -133,12 +136,16 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder>(
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0});
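// Editor's sketch (hypothetical parameters): the copy distributes the
// [EPerBlock, BPerBlock] slice so that SubLengths * ClusterLengths covers it
// per dimension and the cluster product equals BlockSize. E.g. for
// EPerBlock = 8, BPerBlock = 128, BlockSize = 256, one consistent choice is
//   using InBlockCopySubLengths_E_B     = Sequence<4, 1>;
//   using InBlockCopyClusterLengths_E_B = Sequence<2, 128>;
// since 4 * 2 == 8, 1 * 128 == 128 and 2 * 128 == 256.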
// weight tensor
......@@ -152,19 +159,30 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has the blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder>({0, k_block_data_on_global}, {0, 0});
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -172,11 +190,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
// b_mtx[EPerBlock, BPerBlock] is in LDS
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
// register
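// Editor's note: written out, the block-level update is
//   c_mtx[k, b] += sum_e a_mtx[e, k] * b_mtx[e, b],
// an [EPerBlock, KPerBlock]-transposed times [EPerBlock, BPerBlock] product;
// globally the GEMM dimensions are M = K (output channels), N = B = N*Ho*Wo,
// and the reduction axis is E = C*Y*X.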
constexpr auto a_e_k_block_mtx_desc =
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc =
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(in_e_b_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
// sanity check
static_assert(
......@@ -242,8 +258,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
__syncthreads();
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
}
// copy output: register to global memory
......@@ -285,23 +301,27 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
using OutThreadCopySliceLengths =
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2<
decltype(out_k0_k1_b_thread_desc),
decltype(out_k0_k1_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_b_global_desc)>,
OutThreadCopySliceLengths>({0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global});
auto threadwise_out_copy =
ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_k0_k1_b_thread_desc),
decltype(out_k0_k1_b_global_desc),
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 3, 1>::type,
arithmetic_sequence_gen<0, 3, 1>::type,
2,
2,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>(
{0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global});
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
{
threadwise_out_copy.Run(p_out_thread, p_out_global);
threadwise_out_copy.MoveSrcSlicingWindow(Sequence<0, 0, GemmNPerThreadSubC>{},
True);
threadwise_out_copy.MoveDstSlicingWindow(Sequence<0, 0, B1>{}, True);
threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True);
}
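// Editor's note: a thread's GemmNRepeat sub-tiles are contiguous over
// GemmNPerThreadSubC elements in registers but sit B1 apart in the global B
// dimension (B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster),
// hence the different source and destination window steps above.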
}
}
......
......@@ -5,9 +5,9 @@
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
namespace ck {
......@@ -83,7 +83,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
constexpr index_t E = C * Y * X;
constexpr index_t B = N * Ho * Wo;
static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
......@@ -134,8 +136,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
......@@ -159,25 +159,30 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has the blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -185,11 +190,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
// b_mtx[EPerBlock, BPerBlock] is in LDS
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc =
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc =
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(in_e_b_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
// sanity check
static_assert(
......@@ -248,8 +251,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
blockwise_in_copy.template Run<Float, address_space_t::global>(p_in_global,
p_in_block_double);
blockwise_wei_copy.template Run<Float, address_space_t::global>(p_wei_global,
p_wei_block_double);
}
// LDS double buffer: main body
......@@ -271,51 +276,54 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
blockwise_in_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
blockwise_in_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
......@@ -365,29 +373,29 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
using OutThreadCopySliceLengths =
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
decltype(out_k0_k1_b_thread_desc),
decltype(out_k0_k1_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_b_global_desc)>,
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 3, 1>::type,
arithmetic_sequence_gen<0, 3, 1>::type,
2,
2,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global});
auto threadwise_out_copy =
ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_k0_k1_b_thread_desc),
decltype(out_k0_k1_b_global_desc),
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 3, 1>::type,
arithmetic_sequence_gen<0, 3, 1>::type,
2,
2,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>(
{0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global});
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
{
threadwise_out_copy.Run(p_out_thread, p_out_global);
threadwise_out_copy
.template Run<Float, address_space_t::generic, address_space_t::global>(
p_out_thread, p_out_global);
threadwise_out_copy.MoveSrcSlicingWindow(Sequence<0, 0, GemmNPerThreadSubC>{},
True);
threadwise_out_copy.MoveDstSlicingWindow(Sequence<0, 0, B1>{}, True);
threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True);
}
}
}
......
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
typename InBlockCopySubLengths_E_B,
typename InBlockCopyClusterLengths_E_B,
typename InBlockCopyThreadClusterArrangeOrder,
typename InBlockCopySrcAccessOrder,
typename InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B,
typename WeiBlockCopySubLengths_E_K,
typename WeiBlockCopyClusterLengths_E_K,
typename WeiBlockCopyThreadClusterArrangeOrder,
typename WeiBlockCopySrcAccessOrder,
typename WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_hi_wi_global_desc =
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr auto wei_k_c_y_x_global_desc =
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
constexpr auto out_n_k_ho_wo_global_desc =
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N * Ho * Wo;
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
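// Editor's sketch (hypothetical sizes, assuming the default row-major cluster
// order): with K = 256, KPerBlock = 128, B = 25088 (= 8 * 56 * 56) and
// BPerBlock = 64, the grid has KBlockWork * BBlockWork = 2 * 392 = 784
// workgroups; block id 5 maps to cluster index {0, 5}, i.e.
// k_block_data_on_global = 0 and b_block_data_on_global = 5 * 64 = 320.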
// input tensor
// global mem
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
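// Editor's note: the chain of transforms above is the im2col view built
// lazily, with no materialized buffer. Pad applies LeftPads/RightPads to
// (Hi, Wi); Embed expands each padded spatial axis into a
// (filter-tap, output-position) pair via
//   hip = y * ConvDilationH + ho * ConvStrideH
//   wip = x * ConvDilationW + wo * ConvStrideW
// and Merge folds (C, Y, X) into E and (N, Ho, Wo) into B, yielding the
// [E, B] input matrix of the implicit GEMM.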
// LDS mem
// be careful of LDS alignment
constexpr auto in_e_b_block_desc =
make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});
// input blockwise copy
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0});
// weight tensor
// global mem
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
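// Editor's note: unfold_tensor_descriptor collapses dimensions 1..3 (C, Y, X)
// of the KCYX weight into one E axis (they are contiguous in this layout),
// and the reorder swaps to an [E, K] view, so the weight tensor needs no
// Pad/Embed treatment to act as the GEMM's A matrix.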
// LDS
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// weight blockwise copy
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlock, BPerBlock] is in LDS
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
// sanity check
static_assert(
KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
constexpr index_t GemmNRepeat =
BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);
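// Editor's sketch (hypothetical numbers): with KPerBlock = 128,
// GemmMPerThreadSubC = 4, GemmMLevel0Cluster = 4 and GemmMLevel1Cluster = 4,
// GemmMRepeat = 128 / (4 * 4 * 4) = 2: each thread owns two M sub-tiles of
// 4 values each, repeated at a stride of K1 = 64 in the block's K range.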
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_b_block_mtx_desc),
decltype(c_k0k1_b0b1_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
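// Editor's sketch (example numbers): with InBlockCopyDataPerAccess_B = 4,
// WeiBlockCopyDstDataPerWrite_K = 4 and GemmDataPerReadA = GemmDataPerReadB = 4,
// max_align = lcm(4, 4, 4, 4) = 4, so each half of the double buffer is padded
// to a multiple of 4 floats and both halves start at equally aligned offsets.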
// register allocation for output
Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
p_in_global, p_in_block_double);
blockwise_wei_copy.template Run<Float, Float, address_space_t::global>(
p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
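// Editor's note: iloop alternates the two halves of the double buffer. On
// even passes the GEMM consumes the first half while the next E-chunk is
// staged into the second; on odd passes the roles swap. The __syncthreads()
// below also guarantees the previous iteration's store into the "now" half
// has finished before any thread reads it.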
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
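// Editor's note: the main loop's condition (e_block_data_begin + 2 * EPerBlock
// < E) together with the assert E % (2 * EPerBlock) == 0 leaves exactly two
// EPerBlock chunks for the tail: one full even iteration (load + GEMM + store),
// then a final GEMM on the second half of the double buffer with no further
// global loads.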
// copy output: register to global memory
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
// src descriptor
constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});
// dst descriptor
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t K0 = K / K1;
constexpr index_t B0 = B / B1;
constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
out_k_b_global_desc,
make_tuple(UnMerge<Sequence<K0, K1>>{}, UnMerge<Sequence<B0, B1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
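// Editor's note: UnMerge splits k into (k0, k1) with k = k0 * K1 + k1 and b
// into (b0, b1) with b = b0 * B1 + b1, mirroring the blockwise GEMM's thread
// tiling; this is why the copy origin below decomposes the global indices as
// {k / K1, k % K1, b / B1, b % B1}.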
// output threadwise copy
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(out_k0_k1_b0_b1_thread_desc),
decltype(out_k0_k1_b0_b1_global_desc),
decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 4, 1>::type,
3,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1})
#if 1
.template Run<Float, Float, address_space_t::generic, address_space_t::global>
#else // tweaking
.template Run_optimized_dst_address_calculation<Float,
Float,
address_space_t::generic,
address_space_t::global>
#endif
(p_out_thread, p_out_global);
}
}
};
} // namespace ck
#endif
......@@ -3,6 +3,7 @@
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "tensor_descriptor.hpp"
namespace ck {
......@@ -52,10 +53,21 @@ __host__ __device__ constexpr auto
return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}
template <class TDesc>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(TDesc)
template <typename... Ts>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorDescriptor<Ts...>)
{
using TDesc = ConstantTensorDescriptor<Ts...>;
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
static_assert(TDesc::GetStrides()[1] == 1, "wrong");
return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
TDesc::GetLengths()[1],
TDesc::GetStrides()[0]>{};
}
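// Editor's sketch (hypothetical usage of the overload above): a packed 2-D
// tensor descriptor has inner stride 1, so it converts directly:
//
//   constexpr auto tdesc = make_ConstantTensorDescriptor_packed(Sequence<16, 64>{});
//   constexpr auto mdesc = make_ConstantMatrixDescriptor(tdesc);
//   // mdesc is ConstantMatrixDescriptor<16, 64, 64>: 16 rows, 64 columns,
//   // row stride 64 (packed, so stride[0] == lengths[1])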
template <typename... Ts>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor<Ts...>)
{
using TDesc = NativeTensorDescriptor<Ts...>;
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
static_assert(TDesc::GetStrides()[1] == 1, "wrong");
return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
......@@ -63,7 +75,7 @@ __host__ __device__ constexpr auto
TDesc::GetStrides()[0]>{};
}
template <class TDesc>
template <typename TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{
printf(
......
......@@ -111,7 +111,7 @@ struct ConstantMergedTensorDescriptor
index_t itmp = original_multi_id_partial[I];
original_multi_id.Set(Number<idim_original>{}, itmp);
original_multi_id(idim_original) = itmp;
}
};
......