Commit 947ede80 authored by Paul
Browse files

Fix tidy errors

parent c7888300
...@@ -91,9 +91,9 @@ instruction_ref program::insert_instruction(instruction_ref ins, ...@@ -91,9 +91,9 @@ instruction_ref program::insert_instruction(instruction_ref ins,
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
// TODO: Use move // TODO: Use move
shape r = compute_shape(op, args); shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, args}); auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
backreference(result); backreference(result);
assert(result->arguments == args); // assert(result->arguments == args);
assert(result->valid(begin())); assert(result->valid(begin()));
return result; return result;
} }
...@@ -108,7 +108,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, ...@@ -108,7 +108,7 @@ instruction_ref program::replace_instruction(instruction_ref ins,
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args); shape r = compute_shape(op, args);
ins->replace(op, r, args); ins->replace(op, r, std::move(args));
backreference(ins); backreference(ins);
assert(ins->valid(begin())); assert(ins->valid(begin()));
return ins; return ins;
......
...@@ -7,7 +7,7 @@ namespace device { ...@@ -7,7 +7,7 @@ namespace device {
void add_relu(argument result, argument arg1, argument arg2) void add_relu(argument result, argument arg1, argument arg2)
{ {
nary_standard(result, arg1, arg2)([](auto x, auto y) { return max(0, x + y); }); nary_standard(std::move(result), std::move(arg1), std::move(arg2))([](auto x, auto y) { return max(0, x + y); });
} }
} // namespace device } // namespace device
......
...@@ -8,7 +8,7 @@ namespace device { ...@@ -8,7 +8,7 @@ namespace device {
void contiguous(argument result, argument arg) void contiguous(argument result, argument arg)
{ {
nary_nonstandard(result, arg)([](auto x) { return x; }); nary_nonstandard(std::move(result), std::move(arg))([](auto x) { return x; });
} }
} // namespace device } // namespace device
......
...@@ -22,27 +22,34 @@ auto nary(argument result, Arguments... args) ...@@ -22,27 +22,34 @@ auto nary(argument result, Arguments... args)
}; };
} }
template <class... Arguments> template <class F, class... Arguments>
auto nary_nonstandard(argument result, Arguments... args) auto nary_nonstandard_impl(F f, argument result, Arguments... args)
{ {
return [=](auto f) { const auto& output_shape = result.get_shape();
auto output_shape = result.get_shape(); visit_all(result, args...)([&](auto output, auto... inputs) {
visit_all(result, args...)([&](auto output, auto... inputs) { visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) { auto data = make_sequence(
auto data = make_sequence( std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(),
std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(), inputs.get_shape().strides()},
inputs.get_shape().strides()}, inputs.data())...);
inputs.data())...); hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides()); auto* outp = output.data();
auto* outp = output.data(); gs_launch(output_shape.elements())([=](auto i) {
gs_launch(output_shape.elements())([=](auto i) { data([&](auto... ps) {
data([&](auto... ps) { auto outidx = out_desc.multi(i);
auto outidx = out_desc.multi(i); outp[i] = f(ps.second[ps.first.linear(outidx)]...);
outp[i] = f(ps.second[ps.first.linear(outidx)]...);
});
}); });
}); });
}); });
});
}
template <class... Arguments>
auto nary_nonstandard(argument result, Arguments... args)
{
return [=](auto f) {
return nary_nonstandard_impl(f, result, args...);
}; };
} }
...@@ -51,7 +58,7 @@ auto nary_standard(argument result, Arguments... args) ...@@ -51,7 +58,7 @@ auto nary_standard(argument result, Arguments... args)
{ {
return [=](auto f) { return [=](auto f) {
// assert(x.get_shape().elements() == y.get_shape().elements()); // assert(x.get_shape().elements() == y.get_shape().elements());
auto output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
auto data = make_sequence(inputs.data()...); auto data = make_sequence(inputs.data()...);
auto* outp = output.data(); auto* outp = output.data();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment