Commit 41f3973d authored by rocking's avatar rocking
Browse files

Add reduce mean and square mean

parent b42dbb47
......@@ -45,11 +45,13 @@ using D1ReduceOp = ck::reduce::Add<ReduceAccDataType>;
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
using UnaryIdenticElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>;
using UnaryDivElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
using UnarySquareElementOp =
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
......
......@@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
......
......@@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
......
......@@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
......
......@@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
......
......@@ -19,6 +19,7 @@ namespace device_gemm_instance {
using F32 = float;
using F16 = ck::half_t;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
......@@ -122,11 +123,12 @@ bool profile_gemm_reduce_impl(int do_verification,
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add<float>;
using D1ReduceOp = ck::reduce::Add<float>;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add<float>;
using D1ReduceOp = ck::reduce::Add<float>;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryIdentic<float, float, true>;
using UnaryIdenticElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<float, float, false>;
using UnarySquareElementOp =
......@@ -138,7 +140,7 @@ bool profile_gemm_reduce_impl(int do_verification,
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
const auto dxs_in_element_op = DxsInElementOps{};
const auto dxs_out_element_op = DxsOutElementOps{};
const auto dxs_out_element_op = DxsOutElementOp{M, M};
const auto d0_reduce_op = D0ReduceOp{};
const auto d1_reduce_op = D1ReduceOp{};
......@@ -170,6 +172,8 @@ bool profile_gemm_reduce_impl(int do_verification,
d1_reduce_op(d1_acc, d1_val);
}
dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc);
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc);
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment