Commit 759daeb6 authored by Paul's avatar Paul
Browse files

Remove cpu verion for miopen_add

parent 9fed6960
......@@ -97,8 +97,8 @@ struct check_shapes
const check_shapes& not_broadcasted() const
{
// if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
// MIGRAPH_THROW(prefix() + "Shapes are broadcasted");
if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
MIGRAPH_THROW(prefix() + "Shapes are broadcasted");
return *this;
}
......
......@@ -180,38 +180,22 @@ struct miopen_add
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
if(args[1].get_shape().broadcasted())
{
argument result{output_shape};
visit_all(result, from_gpu(args[0]), from_gpu(args[1]))(
[&](auto output, auto input1, auto input2) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) =
input1(idx.begin(), idx.end()) + input2(idx.begin(), idx.end());
});
});
return to_gpu(result);
}
else
{
float alpha = 1, beta = 0;
auto a_desc = make_tensor(args[0].get_shape());
auto b_desc = make_tensor(args[1].get_shape());
auto c_desc = make_tensor(output_shape);
miopenOpTensor(ctx.handle.get(),
miopenTensorOpAdd,
&alpha,
a_desc.get(),
args[0].implicit(),
&alpha,
b_desc.get(),
args[1].implicit(),
&beta,
c_desc.get(),
args[2].implicit());
return args[2];
}
float alpha = 1, beta = 0;
auto a_desc = make_tensor(args[0].get_shape());
auto b_desc = make_tensor(args[1].get_shape());
auto c_desc = make_tensor(output_shape);
miopenOpTensor(ctx.handle.get(),
miopenTensorOpAdd,
&alpha,
a_desc.get(),
args[0].implicit(),
&alpha,
b_desc.get(),
args[1].implicit(),
&beta,
c_desc.get(),
args[2].implicit());
return args[2];
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment