"library/src/host_tensor/device.cpp" did not exist on "d626dccc952b90397baa60fd3633a3504e93f92a"
Commit 961556eb authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed thread_buff r/w

parent 0486af77
......@@ -122,7 +122,10 @@ struct ThreadwiseGenericTensorSliceCopy_v5
__device__ static auto buffer_vector_load(const SrcData* p_src, const SrcCoord src_coord_begin)
{
auto src_offset = src_coord_begin.GetOffset();
return amd_buffer_load<SrcData, SrcDataPerAccess>(p_src, src_offset, true, SrcDataRange);
auto r = GetRegBuffer<SrcData, SrcDataPerAccess>();
r.GetVector(Number<SrcDataPerAccess>{})(Number<0>{}) =
amd_buffer_load<SrcData, SrcDataPerAccess>(p_src, src_offset, true, SrcDataRange);
return r;
}
template <typename DstData, index_t DstDataPerAccess>
......@@ -187,12 +190,17 @@ struct ThreadwiseGenericTensorSliceCopy_v5
auto src_buff = buffer_vector_load<SrcDataPerRead, SrcDesc::GetElementSpace()>(
p_src, src_coord);
// store data from the long-vector buffer to dst
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size;
static_for<0, SrcDataPerRead, 1>{}([&](auto i) {
constexpr auto vector_id = long_vector_data_begin_id.Modify(
Number<vector_access_dim>{}, long_vector_access_id[vector_access_dim] + i);
thread_buff.GetVector(Number<SrcDataPerRead>{})(Number<buff_off>{}) = src_buff;
// store data from the long-vector buffer to dst
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(vector_id));
thread_buff.GetVector(Number<1>{})(Number<buff_off>{}) =
src_buff.GetVector(Number<1>{})(Number<i>{});
});
});
}
......@@ -216,16 +224,23 @@ struct ThreadwiseGenericTensorSliceCopy_v5
Number<vector_access_dim>{},
Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size;
auto src_buff = GetRegBuffer<DstData, DstDataPerWrite>();
static_for<0, DstDataPerWrite, 1>{}([&](auto i) {
constexpr auto vector_id = long_vector_data_begin_id.Modify(
Number<vector_access_dim>{}, long_vector_access_id[vector_access_dim] + i);
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(vector_id));
auto src_buff =
thread_buff.GetVector(Number<DstDataPerWrite>{})[Number<buff_off>{}];
src_buff.GetVector(Number<1>{})(Number<i>{}) =
thread_buff.GetVector(Number<1>{})[Number<buff_off>{}];
});
const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
vector_data_store<DstData, DstDataPerWrite>::run(p_dst, src_buff, dst_coord);
vector_data_store<DstData, DstDataPerWrite>::run(
p_dst, src_buff.GetVector(Number<DstDataPerWrite>{})[Number<0>{}], dst_coord);
});
}
......
......@@ -24,11 +24,41 @@ typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
union float_vec1_t
{
StaticallyIndexedArray<float, 1> s1;
__host__ __device__ constexpr float_vec1_t() { s1(Number<0>{}) = 0; }
template <index_t vs>
__host__ __device__ auto& GetVector(Number<vs>);
template <>
__host__ __device__ auto& GetVector(Number<1>)
{
return s1;
}
};
union float_vec2_t
{
StaticallyIndexedArray<float, 2> s1;
StaticallyIndexedArray<float2_t, 1> s2;
__host__ __device__ constexpr float_vec2_t() { s2(Number<0>{}) = 0; }
template <index_t vs>
__host__ __device__ auto& GetVector(Number<vs>);
template <>
__host__ __device__ auto& GetVector(Number<1>)
{
return s1;
}
template <>
__host__ __device__ auto& GetVector(Number<2>)
{
return s2;
}
};
union float_vec4_t
......@@ -248,6 +278,18 @@ union float_vec128_t
template <typename T, index_t BufferSize>
constexpr auto GetRegBuffer();
template <>
constexpr auto GetRegBuffer<float, 1>()
{
return float_vec1_t{};
}
template <>
constexpr auto GetRegBuffer<float, 2>()
{
return float_vec2_t{};
}
template <>
constexpr auto GetRegBuffer<float, 4>()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment