"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "b111aa5a91769e5af0edf7259773b20514f9883f"
Commit 44e87b4e authored by rocking's avatar rocking
Browse files

Implement reduction meand and reduction square mean

parent ac543313
...@@ -45,7 +45,7 @@ using D1ReduceOp = ck::reduce::Add<ReduceAccDataType>; ...@@ -45,7 +45,7 @@ using D1ReduceOp = ck::reduce::Add<ReduceAccDataType>;
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>; using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
using UnaryIdenticElementOp = using UnaryIdenticElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>; ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
using UnarySquareElementOp = using UnarySquareElementOp =
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>; ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>; using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
...@@ -181,6 +181,9 @@ int main(int argc, char* argv[]) ...@@ -181,6 +181,9 @@ int main(int argc, char* argv[])
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()), auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer())); static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
auto dxs_in_element_op = DxsInElementOp{};
auto dxs_out_element_op = DxsOutElementOp{M, M};
// do GEMM // do GEMM
auto gemm = DeviceGemmReduceInstance{}; auto gemm = DeviceGemmReduceInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
...@@ -197,8 +200,8 @@ int main(int argc, char* argv[]) ...@@ -197,8 +200,8 @@ int main(int argc, char* argv[])
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
DxsInElementOp{}, dxs_in_element_op,
DxsOutElementOp{}); dxs_out_element_op);
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
...@@ -256,12 +259,14 @@ int main(int argc, char* argv[]) ...@@ -256,12 +259,14 @@ int main(int argc, char* argv[])
float d0_val = 0; float d0_val = 0;
float d1_val = 0; float d1_val = 0;
UnaryIdenticElementOp{}(d0_val, c_val); dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
UnarySquareElementOp{}(d1_val, c_val); dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
d0_reduce_op(d0_acc, d0_val); d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val); d1_reduce_op(d1_acc, d1_val);
} }
dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc);
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc);
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc); d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc); d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment