Commit 7bbcd0fe authored by Jing Zhang's avatar Jing Zhang
Browse files

vector type demo for fp32

parent bbdb77e8
......@@ -96,17 +96,23 @@ struct ThreadwiseGenericTensorSliceCopy_v5
template <typename SrcCoord>
__device__ static float run(const float* p_src, const SrcCoord src_coord_begin)
{
constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};
auto scalar_id = make_zero_array<index_t, nDim>();
float r;
r = load_data<float, float>(p_src, src_coord_begin.GetOffset());
return r;
}
};
scalar_id(vector_access_dim) = 0;
auto src_coord = src_coord_begin + scalar_id;
r = load_data<float, float>(p_src, src_coord.GetOffset());
template <typename DstData, index_t DstDataPerAccess, index_t VectorSize>
struct vector_data_store;
return r;
template <>
struct vector_data_store<float, 1, 1>
{
template <typename DstCoord>
__device__ static void
run(float* p_dst, const float src_data, const DstCoord dst_coord_begin)
{
store_data<float, float>(src_data, p_dst, dst_coord_begin.GetOffset());
}
};
......@@ -131,56 +137,20 @@ struct ThreadwiseGenericTensorSliceCopy_v5
Number<vector_access_dim>{},
Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
// buffer to hold a src long-vector
SrcData long_vector[long_vector_size];
#if 1
// zero out buffer
static_for<0, long_vector_size, 1>{}([&](auto i) { long_vector[i] = 0; });
#endif
// load data from src to the long-vector buffer
static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) {
auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * src_data_per_access;
const index_t buffer_offset = i * src_data_per_access;
const auto src_coord =
mSrcSliceOrigin + (to_multi_index(long_vector_data_begin_id) + scalar_id);
// Check src data's valid mapping situation, only check the first data in this
// src
// vector. It's user's responsiblity to make sure all data in the src vector
// has the valid/invalid mapping situation
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
SrcDataStride,
1>(p_src,
src_coord.GetOffset(),
src_coord.IsOffsetValidAssumingUpperIndexIsValid(),
SrcDesc::GetElementSpace(),
long_vector,
buffer_offset,
true,
long_vector_size);
});
const auto src_coord = mSrcSliceOrigin + to_multi_index(long_vector_data_begin_id);
// store data from the long-vector buffer to dst
static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) {
auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * dst_data_per_access;
auto src_buff = vector_data_load<SrcData, SrcDataPerRead, long_vector_size>::run(
p_src, src_coord);
const index_t buffer_offset = i * dst_data_per_access;
// store data from the long-vector buffer to dst
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id));
constexpr auto buff_off = ThreadBufferDesc::CalculateOffset(
to_multi_index(long_vector_data_begin_id));
// static_assert(buff_off == 0 || buff_off == 1 || buff_off == 2 || buff_off == 3,
// "");
thread_buff.s1(Number<buff_off>{}) = long_vector[buffer_offset];
});
thread_buff.s1(Number<buff_off>{}) = src_buff;
});
}
......@@ -201,62 +171,19 @@ struct ThreadwiseGenericTensorSliceCopy_v5
static_ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}(
[&](auto long_vector_access_id) {
// data id w.r.t slicing-window
auto long_vector_data_begin_id = to_multi_index(long_vector_access_id);
long_vector_data_begin_id(vector_access_dim) =
long_vector_size * long_vector_access_id[vector_access_dim];
// buffer to hold a src long-vector
DstData long_vector[long_vector_size];
#if 1
// zero out buffer
static_for<0, long_vector_size, 1>{}([&](auto i) { long_vector[i] = 0; });
#endif
// load data from src to the long-vector buffer
static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) {
auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * src_data_per_access;
constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
Number<vector_access_dim>{},
Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
const index_t buffer_offset = i * src_data_per_access;
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id));
auto buff_off =
ThreadBufferDesc::CalculateOffset(long_vector_data_begin_id + scalar_id);
auto src_buff = thread_buff.s1[Number<buff_off>{}];
// long_vector[buffer_offset] = thread_buff.s1[Number<buff_off>{}];
long_vector[buffer_offset] = thread_buff.n[buff_off];
});
const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
// store data from the long-vector buffer to dst
static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) {
auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * dst_data_per_access;
const index_t buffer_offset = i * dst_data_per_access;
const auto dst_coord =
mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);
// Check dst data's valid mapping situation, only check the first data in this
// dst
// vector. It's user's responsiblity to make sure all data in the dst vector
// has the valid/invalid mapping situation
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp,
1,
DstDataStride>(long_vector,
buffer_offset,
true,
long_vector_size,
p_dst,
dst_coord.GetOffset(),
dst_coord.IsOffsetValidAssumingUpperIndexIsValid(),
DstDesc::GetElementSpace());
});
vector_data_store<DstData, DstDataPerWrite, long_vector_size>::run(
p_dst, src_buff, dst_coord);
});
}
......@@ -280,7 +207,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
.Else([&](auto) { mDstSliceOrigin -= step_sizes; });
}
float_vec8_t thread_buff;
float_vec4_t thread_buff;
private:
SrcCoord mSrcSliceOrigin;
......
......@@ -32,8 +32,39 @@ union float_vec2_t
union float_vec4_t
{
Tuple<float, float, float, float> s1;
struct{
float e0, e1, e2, e3;
} ss1;
float4_t s4;
float n[4];
__host__ __device__ constexpr float_vec4_t() {}
template<typename T, index_t i>
__host__ __device__ void set(const T val);
template<>
__host__ __device__ void set<float, 0>(const float val)
{
ss1.e0 = val;
}
template<>
__host__ __device__ void set<float, 1>(const float val)
{
ss1.e1 = val;
}
template<>
__host__ __device__ void set<float, 2>(const float val)
{
ss1.e2 = val;
}
template<>
__host__ __device__ void set<float, 3>(const float val)
{
ss1.e3 = val;
}
};
union float_vec8_t
......
......@@ -120,7 +120,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
#elif 0
// cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256;
......@@ -183,8 +183,8 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 0
......
......@@ -22,487 +22,20 @@ int main(int argc, char* argv[])
{
using namespace ck;
#if 0
// 1x1, 8x8
constexpr index_t N = 2;
constexpr index_t C = 24;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 71;
constexpr index_t WI = 71;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 1x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 73x73
constexpr index_t N = 128;
constexpr index_t C = 160;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35
constexpr index_t N = 128;
constexpr index_t C = 96;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 71;
constexpr index_t WI = 71;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 7x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
// 1x7, 17x17
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 0
// 3x3, 299x299 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 299;
constexpr index_t WI = 299;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 147x147
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 149x149
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 149;
constexpr index_t WI = 149;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 17x17, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 35x35
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 96;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35, stride 2
constexpr index_t N = 128;
constexpr index_t C = 288;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x3, 8x8
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 448;
constexpr index_t Y = 1;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 1>;
using RightPads = Sequence<0, 1>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>;
#elif 0
// 3x3, 147x147
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 7x1, 73x73
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 0
// 3x3, 73x73
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
// 3x3, 28x28
// 1x1, 56x56
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 14x14
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 1x1, 56x56, stride 2
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 7x7, 230x230 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 230;
constexpr index_t WI = 230;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 7;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride = 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 1x1, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#endif
auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment