"vscode:/vscode.git/clone" did not exist on "77212cc1da5c79de8b1e502e8df4842c7115fc41"
Commit 140fde0a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code refinement.

parent 43194a31
......@@ -18,11 +18,12 @@ struct binary : op_name<Derived>
return {s.type()};
return {s.type(), s.lens()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().standard() and input2.get_shape().standard())
if(input1.get_shape().packed() and input2.get_shape().packed())
{
std::transform(input1.begin(),
input1.end(),
......@@ -38,6 +39,7 @@ struct binary : op_name<Derived>
});
}
});
return result;
}
};
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
......@@ -17,7 +17,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct convert
struct convert : unary<convert>
{
shape::type_t target_type = shape::half_type;
......@@ -27,23 +27,26 @@ struct convert
return pack(f(self.target_type, "target_type"));
}
std::string name() const { return "convert"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return {target_type, inputs.front().lens(), inputs.front().strides()};
if (inputs.at(0).packed())
{
return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
else
{
argument result{output_shape};
result.visit([&](auto output) {
args.front().visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
return {target_type, inputs.at(0).lens()};
}
}
return result;
auto apply() const
{
return [](auto x) { return x; };
}
convert(shape::type_t t) : target_type{t} { }
convert() { }
};
} // namespace op
......
......@@ -15,25 +15,31 @@ struct unary : op_name<Derived>
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
if(input.get_shape().standard())
result.visit([&](auto output) {
args[0].visit([&](auto input) {
if(input.get_shape().packed())
{
std::transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
return result;
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) =
static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end()));
});
}
return result;
});
});
return result;
}
};
......
......@@ -29,10 +29,10 @@ TEST_CASE(param_add)
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = p.add_parameter("x", s);
auto hp1 = p.insert_instruction(
std::next(p1), migraphx::op::convert{migraphx::shape::half_type}, p1);
std::next(p1), migraphx::op::convert{}, p1);
auto p2 = p.add_parameter("y", s);
auto hp2 = p.insert_instruction(
std::next(p2), migraphx::op::convert{migraphx::shape::half_type}, p2);
std::next(p2), migraphx::op::convert{}, p2);
auto hs = p.add_instruction(migraphx::op::add{}, hp1, hp2);
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hs);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment