Commit d5c856bc authored by Jing Zhang
Browse files

add At into vector type

parent 5d3b9b73
...@@ -91,11 +91,11 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -91,11 +91,11 @@ struct ThreadwiseGenericTensorSliceCopy_v5
*reinterpret_cast<SrcData*>(&p_dst[dst_offset]) = src_data; *reinterpret_cast<SrcData*>(&p_dst[dst_offset]) = src_data;
} }
template <typename SrcData, index_t SrcDataPerAccess, index_t VectorSize> template <typename SrcData, index_t SrcDataPerAccess>
struct vector_data_load; struct vector_data_load;
template <> template <>
struct vector_data_load<float, 1, 1> struct vector_data_load<float, 1>
{ {
template <typename SrcCoord> template <typename SrcCoord>
__device__ static float run(const float* p_src, const SrcCoord src_coord_begin) __device__ static float run(const float* p_src, const SrcCoord src_coord_begin)
...@@ -106,11 +106,11 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -106,11 +106,11 @@ struct ThreadwiseGenericTensorSliceCopy_v5
} }
}; };
template <typename DstData, index_t DstDataPerAccess, index_t VectorSize> template <typename DstData, index_t DstDataPerAccess>
struct vector_data_store; struct vector_data_store;
template <> template <>
struct vector_data_store<float, 1, 1> struct vector_data_store<float, 1>
{ {
template <typename DstCoord> template <typename DstCoord>
__device__ static void __device__ static void
...@@ -143,14 +143,13 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -143,14 +143,13 @@ struct ThreadwiseGenericTensorSliceCopy_v5
// load data from src to the long-vector buffer // load data from src to the long-vector buffer
const auto src_coord = mSrcSliceOrigin + to_multi_index(long_vector_data_begin_id); const auto src_coord = mSrcSliceOrigin + to_multi_index(long_vector_data_begin_id);
auto src_buff = vector_data_load<SrcData, SrcDataPerRead, long_vector_size>::run( auto src_buff = vector_data_load<SrcData, SrcDataPerRead>::run(p_src, src_coord);
p_src, src_coord);
// store data from the long-vector buffer to dst // store data from the long-vector buffer to dst
constexpr auto buff_off = constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)); ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id));
thread_buff.s1(Number<buff_off>{}) = src_buff; thread_buff.template At<SrcDataPerRead>()(Number<buff_off>{}) = src_buff;
}); });
} }
...@@ -177,12 +176,11 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -177,12 +176,11 @@ struct ThreadwiseGenericTensorSliceCopy_v5
constexpr auto buff_off = constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)); ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id));
auto src_buff = thread_buff.s1[Number<buff_off>{}]; auto src_buff = thread_buff.template At<DstDataPerWrite>()[Number<buff_off>{}];
const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id); const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
vector_data_store<DstData, DstDataPerWrite, long_vector_size>::run( vector_data_store<DstData, DstDataPerWrite>::run(p_dst, src_buff, dst_coord);
p_dst, src_buff, dst_coord);
}); });
} }
......
...@@ -34,6 +34,21 @@ union float_vec4_t ...@@ -34,6 +34,21 @@ union float_vec4_t
StaticallyIndexedArray<float, 4> s1; StaticallyIndexedArray<float, 4> s1;
float4_t s4; float4_t s4;
__host__ __device__ constexpr float_vec4_t() {s4 = {0, 0, 0, 0};} __host__ __device__ constexpr float_vec4_t() {s4 = {0, 0, 0, 0};}
// Width-selected accessor: At<vs>() returns a reference to one of this
// union's members, chosen at compile time by vs. Only the primary template is
// declared here (no body); the widths handled in the lines shown are the two
// explicit specializations below.
// NOTE(review): explicit specializations written at class scope are accepted
// by clang-based toolchains but have historically been rejected by GCC --
// confirm against the compilers this header must build with.
template<index_t vs>
__host__ __device__ auto& At();
// vs == 1: returns the s1 member (declared as StaticallyIndexedArray<float, 4>).
template<>
__host__ __device__ auto& At<1>()
{
return s1;
}
// vs == 4: returns the s4 member (declared as float4_t).
template<>
__host__ __device__ auto& At<4>()
{
return s4;
}
}; };
union float_vec8_t union float_vec8_t
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment