Commit bbdb77e8 authored by Jing Zhang's avatar Jing Zhang
Browse files

static load

parent c1e24c09
...@@ -87,6 +87,12 @@ struct NativeTensorDescriptor ...@@ -87,6 +87,12 @@ struct NativeTensorDescriptor
return offset; return offset;
} }
template <typename Seq>
__host__ __device__ static constexpr index_t CalculateOffset(Seq)
{
return reduce_on_sequence(Seq{} * GetStrides(), math::plus<index_t>{}, Number<0>{});
}
__host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff) __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
{ {
index_t offset_diff = 0; index_t offset_diff = 0;
......
...@@ -75,6 +75,41 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -75,6 +75,41 @@ struct ThreadwiseGenericTensorSliceCopy_v5
mDstSliceOrigin = dst_slice_origin; mDstSliceOrigin = dst_slice_origin;
} }
template <typename DstData, typename SrcData>
__device__ static DstData load_data(const SrcData* p_src, index_t src_offset)
{
return *reinterpret_cast<const DstData*>(&p_src[src_offset]);
}
template <typename DstData, typename SrcData>
__device__ static void store_data(const SrcData src_data, DstData* p_dst, index_t dst_offset)
{
*reinterpret_cast<SrcData*>(&p_dst[dst_offset]) = src_data;
}
template <typename SrcData, index_t SrcDataPerAccess, index_t VectorSize>
struct vector_data_load;
template <>
struct vector_data_load<float, 1, 1>
{
template <typename SrcCoord>
__device__ static float run(const float* p_src, const SrcCoord src_coord_begin)
{
constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};
auto scalar_id = make_zero_array<index_t, nDim>();
float r;
scalar_id(vector_access_dim) = 0;
auto src_coord = src_coord_begin + scalar_id;
r = load_data<float, float>(p_src, src_coord.GetOffset());
return r;
}
};
template <typename SrcData> template <typename SrcData>
__device__ void Load(const SrcData* p_src) __device__ void Load(const SrcData* p_src)
{ {
...@@ -90,12 +125,11 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -90,12 +125,11 @@ struct ThreadwiseGenericTensorSliceCopy_v5
constexpr auto long_vector_access_lengths = SliceLengths::Modify( constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}( static_ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}(
[&](auto long_vector_access_id) { [&](auto long_vector_access_id) {
// data id w.r.t slicing-window constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
auto long_vector_data_begin_id = long_vector_access_id; Number<vector_access_dim>{},
long_vector_data_begin_id(vector_access_dim) = Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
long_vector_size * long_vector_access_id[vector_access_dim];
// buffer to hold a src long-vector // buffer to hold a src long-vector
SrcData long_vector[long_vector_size]; SrcData long_vector[long_vector_size];
...@@ -113,7 +147,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -113,7 +147,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
const index_t buffer_offset = i * src_data_per_access; const index_t buffer_offset = i * src_data_per_access;
const auto src_coord = const auto src_coord =
mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id); mSrcSliceOrigin + (to_multi_index(long_vector_data_begin_id) + scalar_id);
// Check src data's valid mapping situation, only check the first data in this // Check src data's valid mapping situation, only check the first data in this
// src // src
...@@ -142,12 +176,10 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -142,12 +176,10 @@ struct ThreadwiseGenericTensorSliceCopy_v5
const index_t buffer_offset = i * dst_data_per_access; const index_t buffer_offset = i * dst_data_per_access;
const auto dst_coord = constexpr auto buff_off = ThreadBufferDesc::CalculateOffset(
mDstSliceOrigin + (long_vector_data_begin_id + scalar_id); to_multi_index(long_vector_data_begin_id));
auto buff_off = thread_buff.s1(Number<buff_off>{}) = long_vector[buffer_offset];
ThreadBufferDesc::CalculateOffset(long_vector_data_begin_id + scalar_id);
thread_buff[buff_off] = long_vector[buffer_offset];
}); });
}); });
} }
...@@ -167,10 +199,10 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -167,10 +199,10 @@ struct ThreadwiseGenericTensorSliceCopy_v5
constexpr auto long_vector_access_lengths = SliceLengths::Modify( constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}( static_ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}(
[&](auto long_vector_access_id) { [&](auto long_vector_access_id) {
// data id w.r.t slicing-window // data id w.r.t slicing-window
auto long_vector_data_begin_id = long_vector_access_id; auto long_vector_data_begin_id = to_multi_index(long_vector_access_id);
long_vector_data_begin_id(vector_access_dim) = long_vector_data_begin_id(vector_access_dim) =
long_vector_size * long_vector_access_id[vector_access_dim]; long_vector_size * long_vector_access_id[vector_access_dim];
...@@ -192,7 +224,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -192,7 +224,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5
auto buff_off = auto buff_off =
ThreadBufferDesc::CalculateOffset(long_vector_data_begin_id + scalar_id); ThreadBufferDesc::CalculateOffset(long_vector_data_begin_id + scalar_id);
long_vector[buffer_offset] = thread_buff[buff_off]; // long_vector[buffer_offset] = thread_buff.s1[Number<buff_off>{}];
long_vector[buffer_offset] = thread_buff.n[buff_off];
}); });
// store data from the long-vector buffer to dst // store data from the long-vector buffer to dst
...@@ -247,7 +280,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -247,7 +280,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
.Else([&](auto) { mDstSliceOrigin -= step_sizes; }); .Else([&](auto) { mDstSliceOrigin -= step_sizes; });
} }
float thread_buff[8]; float_vec8_t thread_buff;
private: private:
SrcCoord mSrcSliceOrigin; SrcCoord mSrcSliceOrigin;
......
...@@ -7,6 +7,7 @@ namespace ck { ...@@ -7,6 +7,7 @@ namespace ck {
// float // float
typedef float float2_t __attribute__((ext_vector_type(2))); typedef float float2_t __attribute__((ext_vector_type(2)));
typedef float float4_t __attribute__((ext_vector_type(4))); typedef float float4_t __attribute__((ext_vector_type(4)));
typedef float float8_t __attribute__((ext_vector_type(8)));
typedef float float16_t __attribute__((ext_vector_type(16))); typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32))); typedef float float32_t __attribute__((ext_vector_type(32)));
...@@ -21,6 +22,51 @@ typedef ushort ushort2_t __attribute__((ext_vector_type(2))); ...@@ -21,6 +22,51 @@ typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4))); typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8))); typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
union float_vec2_t
{
Tuple<float, float> s1;
float2_t s2;
__host__ __device__ constexpr float_vec2_t() {}
};
union float_vec4_t
{
Tuple<float, float, float, float> s1;
float4_t s4;
__host__ __device__ constexpr float_vec4_t() {}
};
union float_vec8_t
{
Tuple<float, float, float, float, float, float, float, float> s1;
Tuple<float_vec2_t, float_vec2_t, float_vec2_t, float_vec2_t> s2;
struct{
float_vec4_t e0;
float_vec4_t e1;
} ss4;
Tuple<float_vec4_t, float_vec4_t> s4;
float8_t s8;
float n[8];
__host__ __device__ constexpr float_vec8_t() {}
template<typename T, index_t i>
__host__ __device__ void set(const T val);
template<>
__host__ __device__ void set<float_vec4_t, 0>(const float_vec4_t val)
{
ss4.e0 = val;
}
template<>
__host__ __device__ void set<float_vec4_t, 1>(const float_vec4_t val)
{
ss4.e1 = val;
}
};
struct c_vec32_4_t struct c_vec32_4_t
{ {
union VecType union VecType
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment