Commit b7bc3c2b authored by Jing Zhang
Browse files

allow packed elementwise_op

parent d61d9edf
......@@ -44,10 +44,16 @@ using ELayout = Row;
// Element-wise functor used by this example: combines two inputs (a0, a1) into one
// output (a). The only overload with a body here is the ck::half2_t one, which stores
// the average of its two inputs.
struct MultiATest
{
// Generic call operator: declared for arbitrary A / A0 / A1, but no generic body is
// provided in this struct.
// NOTE(review): the next two lines are the same declaration with and without the
// trailing ';' -- this looks like the before/after of a diff; only one of them should
// remain in the real file. Confirm against the repository.
template <typename A, typename A0, typename A1>
__host__ __device__ constexpr void operator()(A& a, const A0& a0, const A1& a1) const
__host__ __device__ constexpr void operator()(A& a, const A0& a0, const A1& a1) const;
// Explicit specialization for ck::half2_t operands (presumably a 2-wide half-precision
// vector type -- TODO confirm).
// NOTE(review): an explicit specialization at class scope is not accepted by every
// compiler -- confirm the supported toolchains.
template <>
__host__ __device__ constexpr void
operator()(ck::half2_t& a, const ck::half2_t& a0, const ck::half2_t& a1) const
{
// a receives the average of the two inputs.
a = (a0 + a1) / 2;
}
// Width this functor advertises through a static member named `vec_len`.
// Presumably the number of scalars consumed per call by the packed overload above;
// verify against the code that reads `vec_len`.
static constexpr ck::index_t vec_len = 2;
};
struct AlphaBetaAdd
......
......@@ -8,6 +8,18 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include <type_traits>
// Compile-time detector for a member named `vec_len`.
//
// has_vec_len<T>::value is true exactly when the expression
// `std::declval<T>().vec_len` is well-formed, and false otherwise
// (including for non-class types).

// Fallback: selected whenever the partial specialization below is not viable.
template <typename T, typename Enable = void>
struct has_vec_len : std::bool_constant<false>
{
};

// Preferred match: the second argument collapses to `void` only if the member access
// compiles; otherwise substitution fails and the fallback above is used.
template <typename T>
struct has_vec_len<T, decltype(static_cast<void>(std::declval<T>().vec_len))>
    : std::bool_constant<true>
{
};
namespace ck {
// Thread-level multi-source, multi-destination tensor slice data movement
......@@ -131,7 +143,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
Number<num>{});
}
#if 1
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template <typename SrcBuffers,
......@@ -143,7 +154,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto dst_vectors = generate_vectors<DstDatas, DstScalarPerVector>();
#if 0
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
......@@ -155,51 +166,94 @@ struct ThreadwiseTensorSliceTransfer_v7r2
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(),
is_src_valid);
});
#endif
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
// copy data from src_bufs into src_vectors
using src_vector_t =
typename remove_cvref_t<decltype(src_vectors[iSrc])>::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_descs[iSrc], src_coords_[iSrc]);
src_vectors(iSrc).template AsType<src_vector_t>()(I0) =
src_bufs[iSrc].template Get<src_vector_t>(src_coords_[iSrc].GetOffset(),
is_src_valid);
// get reference to src data
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
if constexpr(!has_vec_len<decltype(element_op_)>::value)
{
// apply pointwise function
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
return src_vectors[iSrc].template AsType<SrcData>()[i];
},
Number<nSrc>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
return dst_vectors(iDst).template AsType<DstData>()(i);
},
Number<nDst>{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(element_op_, dst_data_refs, src_data_refs);
});
}
else
{
constexpr auto elem_op_vec_len = decltype(element_op_)::vec_len;
return src_vectors[iSrc].template AsType<SrcData>()[i];
},
Number<nSrc>{});
static_assert(is_same<remove_cvref_t<decltype(elem_op_vec_len)>, index_t>::value,
"vec_len in element_op_ type is not index_t");
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
static_assert(elem_op_vec_len == 2 || elem_op_vec_len == 4 || elem_op_vec_len == 8,
"vec_len in element_op_ must be 2, 4, 8");
return dst_vectors(iDst).template AsType<DstData>()(i);
},
Number<nDst>{});
static_assert(SrcScalarPerVector % elem_op_vec_len == 0,
"vec_len in element_op_ cannot be divided by SrcScalarPerVector!");
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(element_op_, dst_data_refs, src_data_refs);
});
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
using elem_op_vec_t =
typename vector_type<SrcData, elem_op_vec_len>::type;
return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
},
Number<nSrc>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
using elem_op_vec_t =
typename vector_type<DstData, elem_op_vec_len>::type;
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
},
Number<nDst>{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(element_op_, dst_data_refs, src_data_refs);
});
}
dst_vectors_tuple_(iAccess) = dst_vectors;
......@@ -227,9 +281,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
}
});
}
#endif
#if 1
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename DstBuffers,
......@@ -280,7 +332,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
}
});
}
#endif
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
......
......@@ -12,7 +12,9 @@ cmake
-save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
-D GPU_TARGETS="gfx908" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE}
#-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment