Commit 872ff12a authored by Paul
Browse files

Use get_shape instead of result

parent 4566709b
......@@ -10,7 +10,7 @@ void auto_contiguous::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
shape s = ins->result;
shape s = ins->get_shape();
if(not s.standard())
{
auto c = p.insert_instruction(std::next(ins), contiguous{}, ins);
......
......@@ -17,7 +17,7 @@ void dead_code_elimination::apply(program& p) const
continue;
const auto i = std::prev(ins);
// Skip instruction with empty shape as output unless it's a builtin
if(i->result.elements() == 0 and not(i->name().front() == '@'))
if(i->get_shape().elements() == 0 and not(i->name().front() == '@'))
continue;
// Skip the last instruction
if(i == last)
......
......@@ -172,7 +172,7 @@ inline shape compute_shape(const operation& op, const std::vector<instruction_re
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->result; });
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return op.compute_shape(shapes);
}
......
......@@ -60,7 +60,7 @@ static void print_program(std::ostream& os, const program& p, F annonate)
os << ")";
}
os << " -> " << ins->result;
os << " -> " << ins->get_shape();
annonate(ins, names);
......@@ -198,7 +198,7 @@ shape program::get_parameter_shape(std::string name) const
}
});
if(ins != this->end())
return ins->result;
return ins->get_shape();
else
return {};
}
......@@ -211,7 +211,7 @@ std::unordered_map<std::string, shape> program::get_parameter_shapes() const
if(ins.name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins.op).parameter;
result[name] = ins.result;
result[name] = ins.get_shape();
}
}
return result;
......@@ -229,7 +229,7 @@ std::size_t program::size() const { return impl->instructions.size(); }
instruction_ref program::begin() const { return impl->instructions.begin(); }
instruction_ref program::end() const { return impl->instructions.end(); }
shape program::get_shape() const { return impl->instructions.back().result; }
shape program::get_shape() const { return impl->instructions.back().get_shape(); }
instruction_ref program::validate() const
{
......@@ -296,7 +296,7 @@ argument generic_eval(const program& p,
}
else if(ins->name() == "@outline")
{
results.emplace(ins, trace(ins, [&] { return argument{ins->result, nullptr}; }));
results.emplace(ins, trace(ins, [&] { return argument{ins->get_shape(), nullptr}; }));
}
else
{
......@@ -309,7 +309,7 @@ argument generic_eval(const program& p,
return results[i];
});
results.emplace(ins,
trace(ins, [&] { return ins->op.compute(ctx, ins->result, values); }));
trace(ins, [&] { return ins->op.compute(ctx, ins->get_shape(), values); }));
}
assert(results.find(ins) != results.end());
}
......
......@@ -44,7 +44,7 @@ void simplify_reshapes::apply(program& p) const
for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->result == (*start)->result and i != (*start);
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
......
......@@ -372,10 +372,10 @@ struct miopen_apply
auto&& op = any_cast<convolution>(ins->op);
auto conv = miopen_convolution{op, make_conv(op)};
auto ws = conv.compile(ctx, ins->result, ins->arguments);
auto ws = conv.compile(ctx, ins->get_shape(), ins->arguments);
auto workspace = insert_allocation(ins, ws, "workspace");
auto output = insert_allocation(ins, ins->result);
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, conv, ins->arguments.at(0), ins->arguments.at(1), workspace, output);
......@@ -385,7 +385,7 @@ struct miopen_apply
{
auto&& op = any_cast<pooling>(ins->op);
auto pd = make_pooling(op);
auto output = insert_allocation(ins, ins->result);
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, miopen_pooling{op, std::move(pd)}, ins->arguments.at(0), output);
......@@ -397,7 +397,7 @@ struct miopen_apply
auto ad = make_relu();
if(op.mode == "relu")
{
auto output = insert_allocation(ins, ins->result);
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, miopen_relu{std::move(ad)}, ins->arguments.at(0), output);
}
......@@ -406,7 +406,7 @@ struct miopen_apply
instruction_ref apply_add(instruction_ref ins)
{
auto output = insert_allocation(ins, ins->result);
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, hip_add{}, ins->arguments.at(0), ins->arguments.at(1), output);
}
......@@ -414,7 +414,7 @@ struct miopen_apply
instruction_ref apply_gemm(instruction_ref ins)
{
auto&& op = any_cast<gemm>(ins->op);
auto output = insert_allocation(ins, ins->result);
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
}
......@@ -422,14 +422,14 @@ struct miopen_apply
instruction_ref apply_contiguous(instruction_ref ins)
{
auto&& op = any_cast<contiguous>(ins->op);
auto output = insert_allocation(ins, ins->result);
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(ins, miopen_contiguous{op}, ins->arguments.at(0), output);
}
instruction_ref apply_batch_norm_inference(instruction_ref ins)
{
auto&& op = any_cast<batch_norm_inference>(ins->op);
auto output = insert_allocation(ins, ins->result);
auto output = insert_allocation(ins, ins->get_shape());
shape old_shape = ins->arguments.at(1)->get_shape();
std::vector<int64_t> new_shape{1, static_cast<int64_t>(old_shape.elements()), 1, 1};
auto reshape_op = reshape{new_shape};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment