Commit cb78cc74 authored by Chao Liu's avatar Chao Liu
Browse files

added implicit gemm v4 (nchw, kcyx)

parent b2439ec9
......@@ -71,11 +71,11 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_E_K = Sequence<8, 32>;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 4;
#endif
constexpr index_t GridSize =
......
......@@ -443,7 +443,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
#elif 1
// 3x3 filter, 28x28 image
constexpr index_t N = 128;
constexpr index_t C = 256;
......@@ -539,7 +539,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
#elif 0
// 1x1 filter, 14x14 image
constexpr index_t N = 128;
constexpr index_t C = 512;
......@@ -551,7 +551,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
#elif 0
// 1x1 filter, 73x73 image
constexpr index_t N = 128;
constexpr index_t C = 64;
......@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
#elif 1
#elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
#elif 1
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
......@@ -655,7 +655,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
#if 0
#if 1
if(Y == 3 && X == 3)
{
host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
......
......@@ -215,7 +215,7 @@ struct ConstantTensorDescriptor
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDim - 1, 1>{}([&](auto IDimReverse) {
static_for<0, nDim, 1>{}([&](auto IDimReverse) {
constexpr index_t idim = nDim - 1 - IDimReverse.Get();
constexpr auto IDim = Number<idim>{};
......@@ -241,7 +241,7 @@ struct ConstantTensorDescriptor
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDim - 1, 1>{}([&](auto IDimReverse) {
static_for<0, nDim, 1>{}([&](auto IDimReverse) {
constexpr index_t idim = nDim - 1 - IDimReverse.Get();
constexpr auto IDim = Number<idim>{};
......
......@@ -3,7 +3,7 @@
// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
// memory layout (ordering of dimensions) can be different between src and dst
// For now, only support SubLengths == 1 on a merged dimension
// For now, only support SubLengths[...] == 1 on a merged dimension
template <index_t BlockSize,
class Float,
class SrcDesc,
......@@ -84,8 +84,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims;
// for now, only support SubLengths.Get() == 1 on a merged dimension that is merge from
// multiple dimensions
// for now, only support SubLengths.Get() == 1 on a merged dimension that constains
// multiple original dimensions
static_for<0, nDim, 1>{}([&](auto IDim_) {
constexpr auto IDim = decltype(IDim_){};
......@@ -292,7 +292,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto fwd) {
// logic for a merged dimension, also works for non-merged dimension, but its logic may
// be unncessarily complicated for compiler to remove uselss calculations
// be unncessarily complicated for compiler to remove calculations that are useless for
// a non-merged dimension
// extract partial original dimensions
constexpr auto src_partial_original_dims =
......@@ -309,6 +310,27 @@ struct BlockwiseGenericTensorSliceCopy_v1
src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
old_src_partial_original_multi_id, StepSize, direction);
#if 0
{
if(debug_flag && get_block_1d_id() == 0)
{
printf("id %5u %5u: "
"old_src_partial_original_multi_id %u %u %u, "
"new_src_partial_original_multi_id %u %u %u, "
"mThreadSrcOffset %u, mThreadDstOffset %u \n",
get_block_1d_id(),
get_thread_local_1d_id(),
old_src_partial_original_multi_id[0],
old_src_partial_original_multi_id[1],
old_src_partial_original_multi_id[2],
new_src_partial_original_multi_id[0],
new_src_partial_original_multi_id[1],
new_src_partial_original_multi_id[2]
);
}
}
#endif
// update "mThreadSrcOriginalMultiId"
static_for<0, src_partial_original_dims.GetSize(), 1>{}([&](auto I_) {
constexpr auto I = decltype(I_){};
......
......@@ -255,28 +255,14 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
for(index_t e = 0; e < E; e += EPerBlock)
{
#if 0
if(e == 1 * EPerBlock && get_block_1d_id() == 0)
if(e == 0 * EPerBlock && get_block_1d_id() == 0)
{
printf("id %5u %5u: "
"mThreadSrcOriginalMultiId %u %u %u %u %u %u %u %u, "
"mThreadSrcPartialOffsets %u %u %u %u, "
"mThreadSrcOffset %u, mThreadDstOffset %u \n",
get_block_1d_id(),
get_thread_local_1d_id(),
blockwise_in_copy.mThreadSrcOriginalMultiId[0],
blockwise_in_copy.mThreadSrcOriginalMultiId[1],
blockwise_in_copy.mThreadSrcOriginalMultiId[2],
blockwise_in_copy.mThreadSrcOriginalMultiId[3],
blockwise_in_copy.mThreadSrcOriginalMultiId[4],
blockwise_in_copy.mThreadSrcOriginalMultiId[5],
blockwise_in_copy.mThreadSrcOriginalMultiId[6],
blockwise_in_copy.mThreadSrcOriginalMultiId[7],
blockwise_in_copy.mThreadSrcPartialOffsets[0],
blockwise_in_copy.mThreadSrcPartialOffsets[1],
blockwise_in_copy.mThreadSrcPartialOffsets[2],
blockwise_in_copy.mThreadSrcPartialOffsets[3],
blockwise_in_copy.mThreadSrcOffset,
blockwise_in_copy.mThreadDstOffset);
blockwise_wei_copy.mThreadSrcOffset,
blockwise_wei_copy.mThreadDstOffset);
}
#endif
// marching slicing window
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment