Commit d5c856bc authored by Jing Zhang's avatar Jing Zhang
Browse files

add At into vector type

parent 5d3b9b73
......@@ -91,11 +91,11 @@ struct ThreadwiseGenericTensorSliceCopy_v5
*reinterpret_cast<SrcData*>(&p_dst[dst_offset]) = src_data;
}
template <typename SrcData, index_t SrcDataPerAccess, index_t VectorSize>
template <typename SrcData, index_t SrcDataPerAccess>
struct vector_data_load;
template <>
struct vector_data_load<float, 1, 1>
struct vector_data_load<float, 1>
{
template <typename SrcCoord>
__device__ static float run(const float* p_src, const SrcCoord src_coord_begin)
......@@ -106,11 +106,11 @@ struct ThreadwiseGenericTensorSliceCopy_v5
}
};
template <typename DstData, index_t DstDataPerAccess, index_t VectorSize>
template <typename DstData, index_t DstDataPerAccess>
struct vector_data_store;
template <>
struct vector_data_store<float, 1, 1>
struct vector_data_store<float, 1>
{
template <typename DstCoord>
__device__ static void
......@@ -143,14 +143,13 @@ struct ThreadwiseGenericTensorSliceCopy_v5
// load data from src to the long-vector buffer
const auto src_coord = mSrcSliceOrigin + to_multi_index(long_vector_data_begin_id);
auto src_buff = vector_data_load<SrcData, SrcDataPerRead, long_vector_size>::run(
p_src, src_coord);
auto src_buff = vector_data_load<SrcData, SrcDataPerRead>::run(p_src, src_coord);
// store data from the long-vector buffer to dst
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id));
thread_buff.s1(Number<buff_off>{}) = src_buff;
thread_buff.template At<SrcDataPerRead>()(Number<buff_off>{}) = src_buff;
});
}
......@@ -177,12 +176,11 @@ struct ThreadwiseGenericTensorSliceCopy_v5
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id));
auto src_buff = thread_buff.s1[Number<buff_off>{}];
auto src_buff = thread_buff.template At<DstDataPerWrite>()[Number<buff_off>{}];
const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
vector_data_store<DstData, DstDataPerWrite, long_vector_size>::run(
p_dst, src_buff, dst_coord);
vector_data_store<DstData, DstDataPerWrite>::run(p_dst, src_buff, dst_coord);
});
}
......
......@@ -34,6 +34,21 @@ union float_vec4_t
StaticallyIndexedArray<float, 4> s1;
float4_t s4;
__host__ __device__ constexpr float_vec4_t() {s4 = {0, 0, 0, 0};}
template<index_t vs>
__host__ __device__ auto& At();
template<>
__host__ __device__ auto& At<1>()
{
return s1;
}
template<>
__host__ __device__ auto& At<4>()
{
return s4;
}
};
union float_vec8_t
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment