Commit cb78cc74 authored by Chao Liu's avatar Chao Liu
Browse files

added implicit gemm v4 (nchw, kcyx)

parent b2439ec9
...@@ -71,11 +71,11 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, ...@@ -71,11 +71,11 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
constexpr index_t InBlockCopySrcDataPerRead_B = 1; constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<1, 4>; using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<8, 32>; using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 4;
#endif #endif
constexpr index_t GridSize = constexpr index_t GridSize =
......
...@@ -443,7 +443,7 @@ int main(int argc, char* argv[]) ...@@ -443,7 +443,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 3x3 filter, 28x28 image // 3x3 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -539,7 +539,7 @@ int main(int argc, char* argv[]) ...@@ -539,7 +539,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 0
// 1x1 filter, 14x14 image // 1x1 filter, 14x14 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 512; constexpr index_t C = 512;
...@@ -551,7 +551,7 @@ int main(int argc, char* argv[]) ...@@ -551,7 +551,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 0
// 1x1 filter, 73x73 image // 1x1 filter, 73x73 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 64; constexpr index_t C = 64;
...@@ -634,7 +634,7 @@ int main(int argc, char* argv[]) ...@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
#elif 0 #elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
#elif 1 #elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
#elif 1 #elif 1
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
...@@ -655,7 +655,7 @@ int main(int argc, char* argv[]) ...@@ -655,7 +655,7 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
#if 0 #if 1
if(Y == 3 && X == 3) if(Y == 3 && X == 3)
{ {
host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads); host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
......
...@@ -215,7 +215,7 @@ struct ConstantTensorDescriptor ...@@ -215,7 +215,7 @@ struct ConstantTensorDescriptor
// do carry check in reversed order, starting from lowest dimension // do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension // don't check the highest dimension
static_for<0, nDim - 1, 1>{}([&](auto IDimReverse) { static_for<0, nDim, 1>{}([&](auto IDimReverse) {
constexpr index_t idim = nDim - 1 - IDimReverse.Get(); constexpr index_t idim = nDim - 1 - IDimReverse.Get();
constexpr auto IDim = Number<idim>{}; constexpr auto IDim = Number<idim>{};
...@@ -241,7 +241,7 @@ struct ConstantTensorDescriptor ...@@ -241,7 +241,7 @@ struct ConstantTensorDescriptor
// do borrow check in reversed order, starting from lowest dimension // do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension // don't check the highest dimension
static_for<0, nDim - 1, 1>{}([&](auto IDimReverse) { static_for<0, nDim, 1>{}([&](auto IDimReverse) {
constexpr index_t idim = nDim - 1 - IDimReverse.Get(); constexpr index_t idim = nDim - 1 - IDimReverse.Get();
constexpr auto IDim = Number<idim>{}; constexpr auto IDim = Number<idim>{};
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor // slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
// memory layout (ordering of dimensions) can be different between src and dst // memory layout (ordering of dimensions) can be different between src and dst
// For now, only support SubLengths == 1 on a merged dimension // For now, only support SubLengths[...] == 1 on a merged dimension
template <index_t BlockSize, template <index_t BlockSize,
class Float, class Float,
class SrcDesc, class SrcDesc,
...@@ -84,8 +84,8 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -84,8 +84,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims; constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims;
// for now, only support SubLengths.Get() == 1 on a merged dimension that is merge from // for now, only support SubLengths.Get() == 1 on a merged dimension that constains
// multiple dimensions // multiple original dimensions
static_for<0, nDim, 1>{}([&](auto IDim_) { static_for<0, nDim, 1>{}([&](auto IDim_) {
constexpr auto IDim = decltype(IDim_){}; constexpr auto IDim = decltype(IDim_){};
...@@ -292,7 +292,8 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -292,7 +292,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto fwd) { static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto fwd) {
// logic for a merged dimension, also works for non-merged dimension, but its logic may // logic for a merged dimension, also works for non-merged dimension, but its logic may
// be unncessarily complicated for compiler to remove uselss calculations // be unncessarily complicated for compiler to remove calculations that are useless for
// a non-merged dimension
// extract partial original dimensions // extract partial original dimensions
constexpr auto src_partial_original_dims = constexpr auto src_partial_original_dims =
...@@ -309,6 +310,27 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -309,6 +310,27 @@ struct BlockwiseGenericTensorSliceCopy_v1
src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex( src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
old_src_partial_original_multi_id, StepSize, direction); old_src_partial_original_multi_id, StepSize, direction);
#if 0
{
if(debug_flag && get_block_1d_id() == 0)
{
printf("id %5u %5u: "
"old_src_partial_original_multi_id %u %u %u, "
"new_src_partial_original_multi_id %u %u %u, "
"mThreadSrcOffset %u, mThreadDstOffset %u \n",
get_block_1d_id(),
get_thread_local_1d_id(),
old_src_partial_original_multi_id[0],
old_src_partial_original_multi_id[1],
old_src_partial_original_multi_id[2],
new_src_partial_original_multi_id[0],
new_src_partial_original_multi_id[1],
new_src_partial_original_multi_id[2]
);
}
}
#endif
// update "mThreadSrcOriginalMultiId" // update "mThreadSrcOriginalMultiId"
static_for<0, src_partial_original_dims.GetSize(), 1>{}([&](auto I_) { static_for<0, src_partial_original_dims.GetSize(), 1>{}([&](auto I_) {
constexpr auto I = decltype(I_){}; constexpr auto I = decltype(I_){};
......
...@@ -255,28 +255,14 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw ...@@ -255,28 +255,14 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
for(index_t e = 0; e < E; e += EPerBlock) for(index_t e = 0; e < E; e += EPerBlock)
{ {
#if 0 #if 0
if(e == 1 * EPerBlock && get_block_1d_id() == 0) if(e == 0 * EPerBlock && get_block_1d_id() == 0)
{ {
printf("id %5u %5u: " printf("id %5u %5u: "
"mThreadSrcOriginalMultiId %u %u %u %u %u %u %u %u, "
"mThreadSrcPartialOffsets %u %u %u %u, "
"mThreadSrcOffset %u, mThreadDstOffset %u \n", "mThreadSrcOffset %u, mThreadDstOffset %u \n",
get_block_1d_id(), get_block_1d_id(),
get_thread_local_1d_id(), get_thread_local_1d_id(),
blockwise_in_copy.mThreadSrcOriginalMultiId[0], blockwise_wei_copy.mThreadSrcOffset,
blockwise_in_copy.mThreadSrcOriginalMultiId[1], blockwise_wei_copy.mThreadDstOffset);
blockwise_in_copy.mThreadSrcOriginalMultiId[2],
blockwise_in_copy.mThreadSrcOriginalMultiId[3],
blockwise_in_copy.mThreadSrcOriginalMultiId[4],
blockwise_in_copy.mThreadSrcOriginalMultiId[5],
blockwise_in_copy.mThreadSrcOriginalMultiId[6],
blockwise_in_copy.mThreadSrcOriginalMultiId[7],
blockwise_in_copy.mThreadSrcPartialOffsets[0],
blockwise_in_copy.mThreadSrcPartialOffsets[1],
blockwise_in_copy.mThreadSrcPartialOffsets[2],
blockwise_in_copy.mThreadSrcPartialOffsets[3],
blockwise_in_copy.mThreadSrcOffset,
blockwise_in_copy.mThreadDstOffset);
} }
#endif #endif
// marching slicing window // marching slicing window
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment