Commit d807d05e authored by Jing Zhang's avatar Jing Zhang
Browse files

add scaleAdd_vec4 example

parent ca0f9579
......@@ -41,21 +41,42 @@ using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
struct Add
struct AddScale
{
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
__host__ __device__ constexpr void
operator()(ck::half2_t& a, const ck::half2_t& a0, const ck::half2_t& a1) const
operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const
{
a = a0 + a1;
const auto a0_v_t = ck::vector_type<ck::half_t, 4>{a0};
const auto a1_v_t = ck::vector_type<ck::half_t, 4>{a1};
auto r_v_t = ck::vector_type<ck::half_t, 4>{};
r_v_t.AsType<ck::half_t>()(I0) =
scale * (a0_v_t.AsType<ck::half_t>()[I0] + a1_v_t.AsType<ck::half_t>()[I0]);
r_v_t.AsType<ck::half_t>()(I1) =
scale * (a0_v_t.AsType<ck::half_t>()[I1] + a1_v_t.AsType<ck::half_t>()[I1]);
r_v_t.AsType<ck::half_t>()(I2) =
scale * (a0_v_t.AsType<ck::half_t>()[I2] + a1_v_t.AsType<ck::half_t>()[I2]);
r_v_t.AsType<ck::half_t>()(I3) =
scale * (a0_v_t.AsType<ck::half_t>()[I3] + a1_v_t.AsType<ck::half_t>()[I3]);
a = r_v_t.AsType<ck::half4_t>()[I0];
}
__host__ __device__ constexpr void
operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const
{
a = a0 + a1;
a = scale * (a0 + a1);
}
static constexpr ck::index_t vec_len = 4;
float scale = 1.0;
};
struct AlphaBetaAdd
......@@ -76,7 +97,7 @@ struct AlphaBetaAdd
float beta_;
};
using AElementOp = Add;
using AElementOp = AddScale;
using BElementOp = PassThrough;
using CDEElementOp = AlphaBetaAdd;
......@@ -248,7 +269,7 @@ int main(int argc, char* argv[])
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto a_element_op = AElementOp{0.2};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
......@@ -312,14 +333,14 @@ int main(int argc, char* argv[])
BDataType,
CShuffleDataType,
AccDataType,
BElementOp,
PassThrough,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, b_element_op, b_element_op, PassThrough{});
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
......
......@@ -133,7 +133,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
}
template <typename T>
using has_vec_len = decltype(std::declval<T&>().vec_len());
using has_vec_len = decltype(std::declval<T&>().vec_len);
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment