Commit 7bbcd0fe authored by Jing Zhang
Browse files

vector type demo for fp32

parent bbdb77e8
...@@ -96,17 +96,23 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -96,17 +96,23 @@ struct ThreadwiseGenericTensorSliceCopy_v5
template <typename SrcCoord> template <typename SrcCoord>
__device__ static float run(const float* p_src, const SrcCoord src_coord_begin) __device__ static float run(const float* p_src, const SrcCoord src_coord_begin)
{ {
constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};
auto scalar_id = make_zero_array<index_t, nDim>();
float r; float r;
r = load_data<float, float>(p_src, src_coord_begin.GetOffset());
return r;
}
};
scalar_id(vector_access_dim) = 0; template <typename DstData, index_t DstDataPerAccess, index_t VectorSize>
auto src_coord = src_coord_begin + scalar_id; struct vector_data_store;
r = load_data<float, float>(p_src, src_coord.GetOffset());
return r; template <>
struct vector_data_store<float, 1, 1>
{
template <typename DstCoord>
__device__ static void
run(float* p_dst, const float src_data, const DstCoord dst_coord_begin)
{
store_data<float, float>(src_data, p_dst, dst_coord_begin.GetOffset());
} }
}; };
...@@ -131,56 +137,20 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -131,56 +137,20 @@ struct ThreadwiseGenericTensorSliceCopy_v5
Number<vector_access_dim>{}, Number<vector_access_dim>{},
Number<long_vector_size * long_vector_access_id[vector_access_dim]>{}); Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
// buffer to hold a src long-vector
SrcData long_vector[long_vector_size];
#if 1
// zero out buffer
static_for<0, long_vector_size, 1>{}([&](auto i) { long_vector[i] = 0; });
#endif
// load data from src to the long-vector buffer // load data from src to the long-vector buffer
static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) { const auto src_coord = mSrcSliceOrigin + to_multi_index(long_vector_data_begin_id);
auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * src_data_per_access;
const index_t buffer_offset = i * src_data_per_access;
const auto src_coord =
mSrcSliceOrigin + (to_multi_index(long_vector_data_begin_id) + scalar_id);
// Check whether the src data maps to a valid offset; only the first element of this
// src vector is checked. It is the user's responsibility to make sure all data in the
// src vector has the same valid/invalid mapping situation
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
SrcDataStride,
1>(p_src,
src_coord.GetOffset(),
src_coord.IsOffsetValidAssumingUpperIndexIsValid(),
SrcDesc::GetElementSpace(),
long_vector,
buffer_offset,
true,
long_vector_size);
});
// store data from the long-vector buffer to dst auto src_buff = vector_data_load<SrcData, SrcDataPerRead, long_vector_size>::run(
static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) { p_src, src_coord);
auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * dst_data_per_access;
const index_t buffer_offset = i * dst_data_per_access; // store data from the long-vector buffer to dst
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id));
constexpr auto buff_off = ThreadBufferDesc::CalculateOffset( // static_assert(buff_off == 0 || buff_off == 1 || buff_off == 2 || buff_off == 3,
to_multi_index(long_vector_data_begin_id)); // "");
thread_buff.s1(Number<buff_off>{}) = long_vector[buffer_offset]; thread_buff.s1(Number<buff_off>{}) = src_buff;
});
}); });
} }
...@@ -201,62 +171,19 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -201,62 +171,19 @@ struct ThreadwiseGenericTensorSliceCopy_v5
static_ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}( static_ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}(
[&](auto long_vector_access_id) { [&](auto long_vector_access_id) {
// data id w.r.t slicing-window constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
auto long_vector_data_begin_id = to_multi_index(long_vector_access_id); Number<vector_access_dim>{},
long_vector_data_begin_id(vector_access_dim) = Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
long_vector_size * long_vector_access_id[vector_access_dim];
// buffer to hold a src long-vector
DstData long_vector[long_vector_size];
#if 1
// zero out buffer
static_for<0, long_vector_size, 1>{}([&](auto i) { long_vector[i] = 0; });
#endif
// load data from src to the long-vector buffer
static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) {
auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * src_data_per_access;
const index_t buffer_offset = i * src_data_per_access; constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id));
auto buff_off = auto src_buff = thread_buff.s1[Number<buff_off>{}];
ThreadBufferDesc::CalculateOffset(long_vector_data_begin_id + scalar_id);
// long_vector[buffer_offset] = thread_buff.s1[Number<buff_off>{}]; const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
long_vector[buffer_offset] = thread_buff.n[buff_off];
});
// store data from the long-vector buffer to dst vector_data_store<DstData, DstDataPerWrite, long_vector_size>::run(
static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) { p_dst, src_buff, dst_coord);
auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * dst_data_per_access;
const index_t buffer_offset = i * dst_data_per_access;
const auto dst_coord =
mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);
// Check whether the dst data maps to a valid offset; only the first element of this
// dst vector is checked. It is the user's responsibility to make sure all data in the
// dst vector has the same valid/invalid mapping situation
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp,
1,
DstDataStride>(long_vector,
buffer_offset,
true,
long_vector_size,
p_dst,
dst_coord.GetOffset(),
dst_coord.IsOffsetValidAssumingUpperIndexIsValid(),
DstDesc::GetElementSpace());
});
}); });
} }
...@@ -280,7 +207,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -280,7 +207,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
.Else([&](auto) { mDstSliceOrigin -= step_sizes; }); .Else([&](auto) { mDstSliceOrigin -= step_sizes; });
} }
float_vec8_t thread_buff; float_vec4_t thread_buff;
private: private:
SrcCoord mSrcSliceOrigin; SrcCoord mSrcSliceOrigin;
......
...@@ -32,8 +32,39 @@ union float_vec2_t ...@@ -32,8 +32,39 @@ union float_vec2_t
union float_vec4_t union float_vec4_t
{ {
Tuple<float, float, float, float> s1; Tuple<float, float, float, float> s1;
struct{
float e0, e1, e2, e3;
} ss1;
float4_t s4; float4_t s4;
float n[4];
__host__ __device__ constexpr float_vec4_t() {} __host__ __device__ constexpr float_vec4_t() {}
template<typename T, index_t i>
__host__ __device__ void set(const T val);
template<>
__host__ __device__ void set<float, 0>(const float val)
{
ss1.e0 = val;
}
template<>
__host__ __device__ void set<float, 1>(const float val)
{
ss1.e1 = val;
}
template<>
__host__ __device__ void set<float, 2>(const float val)
{
ss1.e2 = val;
}
template<>
__host__ __device__ void set<float, 3>(const float val)
{
ss1.e3 = val;
}
}; };
union float_vec8_t union float_vec8_t
......
...@@ -120,7 +120,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -120,7 +120,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1 #elif 0
// cdata = 64, BlockSize = 256, 128x128x8 // cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -183,8 +183,8 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -183,8 +183,8 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4; constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 0 #elif 0
......
...@@ -22,487 +22,20 @@ int main(int argc, char* argv[]) ...@@ -22,487 +22,20 @@ int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
#if 0 // 1x1, 56x56
// 1x1, 8x8
constexpr index_t N = 2;
constexpr index_t C = 24;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 71;
constexpr index_t WI = 71;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 1x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 73x73
constexpr index_t N = 128;
constexpr index_t C = 160;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35
constexpr index_t N = 128;
constexpr index_t C = 96;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 71;
constexpr index_t WI = 71;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 7x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
// 1x7, 17x17
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 0
// 3x3, 299x299 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 299;
constexpr index_t WI = 299;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 147x147
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 149x149
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 149;
constexpr index_t WI = 149;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 17x17, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 35x35
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 96;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35, stride 2
constexpr index_t N = 128;
constexpr index_t C = 288;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x3, 8x8
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 448;
constexpr index_t Y = 1;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 1>;
using RightPads = Sequence<0, 1>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>;
#elif 0
// 3x3, 147x147
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 7x1, 73x73
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 0
// 3x3, 73x73
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
// 3x3, 28x28
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 14x14
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 1x1, 56x56, stride 2
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 56; constexpr index_t HI = 56;
constexpr index_t WI = 56; constexpr index_t WI = 56;
constexpr index_t K = 128; constexpr index_t K = 128;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 7x7, 230x230 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 230;
constexpr index_t WI = 230;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 7;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride = 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 1x1, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#endif
auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{}); auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{}); auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment