Commit 140fde0a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code refinement.

parent 43194a31
...@@ -18,11 +18,12 @@ struct binary : op_name<Derived> ...@@ -18,11 +18,12 @@ struct binary : op_name<Derived>
return {s.type()}; return {s.type()};
return {s.type(), s.lens()}; return {s.type(), s.lens()};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) { visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().standard() and input2.get_shape().standard()) if(input1.get_shape().packed() and input2.get_shape().packed())
{ {
std::transform(input1.begin(), std::transform(input1.begin(),
input1.end(), input1.end(),
...@@ -38,6 +39,7 @@ struct binary : op_name<Derived> ...@@ -38,6 +39,7 @@ struct binary : op_name<Derived>
}); });
} }
}); });
return result; return result;
} }
}; };
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP #define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#include <array> #include <array>
#include <migraphx/op/binary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -17,7 +17,7 @@ namespace migraphx { ...@@ -17,7 +17,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct convert struct convert : unary<convert>
{ {
shape::type_t target_type = shape::half_type; shape::type_t target_type = shape::half_type;
...@@ -27,23 +27,26 @@ struct convert ...@@ -27,23 +27,26 @@ struct convert
return pack(f(self.target_type, "target_type")); return pack(f(self.target_type, "target_type"));
} }
std::string name() const { return "convert"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
return {target_type, inputs.front().lens(), inputs.front().strides()}; if (inputs.at(0).packed())
{
return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
} }
else
argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; return {target_type, inputs.at(0).lens()};
result.visit([&](auto output) { }
args.front().visit( }
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
return result; auto apply() const
{
return [](auto x) { return x; };
} }
convert(shape::type_t t) : target_type{t} { }
convert() { }
}; };
} // namespace op } // namespace op
......
...@@ -15,25 +15,31 @@ struct unary : op_name<Derived> ...@@ -15,25 +15,31 @@ struct unary : op_name<Derived>
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.at(0); return inputs.at(0);
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) { result.visit([&](auto output) {
if(input.get_shape().standard()) args[0].visit([&](auto input) {
if(input.get_shape().packed())
{ {
std::transform(input.begin(), std::transform(input.begin(),
input.end(), input.end(),
output.begin(), output.begin(),
static_cast<const Derived&>(*this).apply()); static_cast<const Derived&>(*this).apply());
return result;
} }
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = output(idx.begin(), idx.end()) =
static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end())); static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end()));
}); });
}
return result;
}); });
});
return result; return result;
} }
}; };
......
...@@ -29,10 +29,10 @@ TEST_CASE(param_add) ...@@ -29,10 +29,10 @@ TEST_CASE(param_add)
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = p.add_parameter("x", s); auto p1 = p.add_parameter("x", s);
auto hp1 = p.insert_instruction( auto hp1 = p.insert_instruction(
std::next(p1), migraphx::op::convert{migraphx::shape::half_type}, p1); std::next(p1), migraphx::op::convert{}, p1);
auto p2 = p.add_parameter("y", s); auto p2 = p.add_parameter("y", s);
auto hp2 = p.insert_instruction( auto hp2 = p.insert_instruction(
std::next(p2), migraphx::op::convert{migraphx::shape::half_type}, p2); std::next(p2), migraphx::op::convert{}, p2);
auto hs = p.add_instruction(migraphx::op::add{}, hp1, hp2); auto hs = p.add_instruction(migraphx::op::add{}, hp1, hp2);
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hs); p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hs);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment