Commit 8d10b4d6 authored by Jianfeng yan's avatar Jianfeng yan
Browse files

add DynamicBuffer::Transfer, but Add is not tested

parent 260f0e93
...@@ -92,7 +92,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve ...@@ -92,7 +92,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve
remove_cv_t<decltype(dst_scalar_per_access)>>; remove_cv_t<decltype(dst_scalar_per_access)>>;
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector? // TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector?
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector; typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type; using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
...@@ -101,14 +102,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve ...@@ -101,14 +102,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve
static_for<0, num_accesses, 1>{}([&](auto idx_1d) { static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
// constexpr auto all_indices = SpaceFillingCurve::GetIndices(idx_1d);
// copy data from src_buf into dst_vector // copy data from src_buf into dst_vector
// TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve?
static_for<0, DstScalarPerVector, 1>{}([&](auto i) { static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
// constexpr index_t src_offset = src_desc.CalculateOffset(
// src_slice_origin_idx + all_indices[i]);
SrcData dst_v; SrcData dst_v;
...@@ -123,36 +122,10 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve ...@@ -123,36 +122,10 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf // copy data from dst_vector into dst_buf
if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) dst_buf.template Transfer<DstInMemOp, dst_vector_t>(
{
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
{
dst_buf.template AtomicAdd<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Add)
{
typename vector_type_maker<DstData, DstScalarPerVector>::type tmp;
tmp.template AsType<dst_vector_t>()(Number<0>{}) =
dst_buf.template Get<dst_vector_t>(dst_coord_.GetOffset(), is_dst_valid);
static_for<0, DstScalarPerVector, 1>{}([&](auto t) {
dst_vector.template AsType<DstData>()(t) += tmp.template AsType<DstData>()[t];
});
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(), dst_coord_.GetOffset(),
is_dst_valid, is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]); dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
if constexpr(idx_1d.value != num_accesses - 1) if constexpr(idx_1d.value != num_accesses - 1)
{ {
...@@ -203,7 +176,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve ...@@ -203,7 +176,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve
remove_cv_t<decltype(dst_scalar_per_access)>>; remove_cv_t<decltype(dst_scalar_per_access)>>;
constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess(); constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
constexpr auto reset_step = SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, I0); constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, I0);
return reset_step; return reset_step;
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "amd_buffer_addressing.hpp" #include "amd_buffer_addressing.hpp"
#include "c_style_pointer_cast.hpp" #include "c_style_pointer_cast.hpp"
#include "config.hpp"
#include "enable_if.hpp" #include "enable_if.hpp"
namespace ck { namespace ck {
...@@ -108,6 +109,31 @@ struct DynamicBuffer ...@@ -108,6 +109,31 @@ struct DynamicBuffer
} }
} }
template <InMemoryDataOperationEnum_t Op,
typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
__host__ __device__ void Transfer(index_t i, bool is_valid_element, const X& x)
{
if constexpr(Op == InMemoryDataOperationEnum_t::Set)
{
this->template Set<X>(i, is_valid_element, x);
}
else if constexpr(Op == InMemoryDataOperationEnum_t::AtomicAdd)
{
this->template AtomicAdd<X>(i, is_valid_element, x);
}
else if constexpr(Op == InMemoryDataOperationEnum_t::Add)
{
auto tmp = this->template Get<X>(i, is_valid_element);
this->template Set<X>(i, is_valid_element, x+tmp);
// tmp += x;
// this->template Set<X>(i, is_valid_element, tmp);
}
}
template <typename X, template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type, typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value, typename scalar_type<remove_cvref_t<T>>::type>::value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment