Commit 947ede80 authored by Paul

Fix tidy errors

parent c7888300
@@ -91,9 +91,9 @@ instruction_ref program::insert_instruction(instruction_ref ins,
     assert(not starts_with(op.name(), "@"));
     // TODO: Use move
     shape r = compute_shape(op, args);
-    auto result = impl->instructions.insert(ins, {op, r, args});
+    auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
     backreference(result);
-    assert(result->arguments == args);
+    // assert(result->arguments == args);
     assert(result->valid(begin()));
     return result;
 }
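
Note (not part of the diff): args is taken by value and is now moved into the inserted instruction, so the old assert comparing result->arguments against args would read a moved-from vector, which is presumably why it is commented out rather than kept. A minimal standalone sketch of that pitfall, assuming nothing beyond a by-value std::vector parameter:

#include <cassert>
#include <utility>
#include <vector>

std::vector<int> stored;

void insert(std::vector<int> args)
{
    stored = std::move(args);   // args is left in an unspecified (typically empty) state
    // assert(stored == args);  // no longer meaningful after the move
    assert(stored.size() == 3); // check the stored value instead
}

int main() { insert({1, 2, 3}); }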
@@ -108,7 +108,7 @@ instruction_ref program::replace_instruction(instruction_ref ins,
     assert(not starts_with(op.name(), "@"));
     shape r = compute_shape(op, args);
-    ins->replace(op, r, args);
+    ins->replace(op, r, std::move(args));
     backreference(ins);
     assert(ins->valid(begin()));
     return ins;
......
@@ -7,7 +7,7 @@ namespace device {
 void add_relu(argument result, argument arg1, argument arg2)
 {
-    nary_standard(result, arg1, arg2)([](auto x, auto y) { return max(0, x + y); });
+    nary_standard(std::move(result), std::move(arg1), std::move(arg2))([](auto x, auto y) { return max(0, x + y); });
 }
 } // namespace device
......
@@ -8,7 +8,7 @@ namespace device {
 void contiguous(argument result, argument arg)
 {
-    nary_nonstandard(result, arg)([](auto x) { return x; });
+    nary_nonstandard(std::move(result), std::move(arg))([](auto x) { return x; });
 }
 } // namespace device
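
The two device changes address the same kind of tidy warning (most likely clang-tidy's performance-unnecessary-value-param): the argument parameters are taken by value and only passed along, so forwarding them with std::move avoids an extra copy. A rough sketch of the pattern, with a hypothetical handle type standing in for argument:

#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for an argument-like handle: cheap to move, refcounted to copy.
struct handle
{
    std::shared_ptr<std::vector<float>> data;
};

void consume(handle h) { (void)h; }

void forward(handle result) // by-value parameter that is only passed on
{
    consume(std::move(result)); // move instead of bumping the reference count again
}

int main() { forward(handle{std::make_shared<std::vector<float>>(16, 0.0f)}); }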
......
@@ -22,27 +22,34 @@ auto nary(argument result, Arguments... args)
     };
 }
 
-template <class... Arguments>
-auto nary_nonstandard(argument result, Arguments... args)
+template <class F, class... Arguments>
+auto nary_nonstandard_impl(F f, argument result, Arguments... args)
 {
-    return [=](auto f) {
-        auto output_shape = result.get_shape();
-        visit_all(result, args...)([&](auto output, auto... inputs) {
-            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
-                auto data = make_sequence(
-                    std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(),
-                                                               inputs.get_shape().strides()},
-                                   inputs.data())...);
-                hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
-                auto* outp = output.data();
-                gs_launch(output_shape.elements())([=](auto i) {
-                    data([&](auto... ps) {
-                        auto outidx = out_desc.multi(i);
-                        outp[i] = f(ps.second[ps.first.linear(outidx)]...);
-                    });
+    const auto& output_shape = result.get_shape();
+    visit_all(result, args...)([&](auto output, auto... inputs) {
+        visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
+            auto data = make_sequence(
+                std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(),
+                                                           inputs.get_shape().strides()},
+                               inputs.data())...);
+            hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
+            auto* outp = output.data();
+            gs_launch(output_shape.elements())([=](auto i) {
+                data([&](auto... ps) {
+                    auto outidx = out_desc.multi(i);
+                    outp[i] = f(ps.second[ps.first.linear(outidx)]...);
                 });
             });
         });
-    };
+    });
 }
+
+template <class... Arguments>
+auto nary_nonstandard(argument result, Arguments... args)
+{
+    return [=](auto f) {
+        return nary_nonstandard_impl(f, result, args...);
+    };
+}
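
The nary_nonstandard rewrite hoists the kernel body out of the returned generic lambda into nary_nonstandard_impl, which takes the callback f as an ordinary parameter; the public helper now returns a thin lambda that only forwards to it, presumably to simplify the nested-lambda body that tidy objected to, and the shape is bound by const reference instead of copied. A stripped-down sketch of the same forwarding pattern, with the GPU machinery replaced by a plain loop and all names hypothetical:

#include <cstddef>
#include <iostream>
#include <vector>

// The loop body lives here, with the callback f as an ordinary parameter.
template <class F, class... Arguments>
std::vector<int> transform_impl(F f, std::size_t n, const Arguments&... args)
{
    std::vector<int> result(n);
    for(std::size_t i = 0; i < n; i++)
        result[i] = f(args[i]...);
    return result;
}

// The public helper keeps the two-step call syntax transform(n, inputs...)(f),
// but its returned lambda only forwards to transform_impl.
template <class... Arguments>
auto transform(std::size_t n, Arguments... args)
{
    return [=](auto f) { return transform_impl(f, n, args...); };
}

int main()
{
    std::vector<int> a{1, 2, 3};
    std::vector<int> b{10, 20, 30};
    auto out = transform(3, a, b)([](auto x, auto y) { return x + y; });
    std::cout << out[0] << ' ' << out[1] << ' ' << out[2] << '\n'; // prints 11 22 33
}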
@@ -51,7 +58,7 @@ auto nary_standard(argument result, Arguments... args)
 {
     return [=](auto f) {
         // assert(x.get_shape().elements() == y.get_shape().elements());
-        auto output_shape = result.get_shape();
+        const auto& output_shape = result.get_shape();
         visit_all(result, args...)([&](auto output, auto... inputs) {
             auto data = make_sequence(inputs.data()...);
             auto* outp = output.data();
......
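
The last hunk makes the same shape-binding change inside nary_standard: const auto& binds to what get_shape() returns without copying, whereas plain auto deduces the shape type and copies it on every call. A tiny illustration of the difference, assuming a getter that returns a const reference:

#include <string>

struct holder
{
    std::string value = "some payload";
    const std::string& get() const { return value; }
};

int main()
{
    holder h;
    auto copy        = h.get(); // auto deduces std::string, so this copies
    const auto& view = h.get(); // binds to the existing member, no copy
    (void)copy;
    (void)view;
}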