Commit 872ff12a authored by Paul's avatar Paul
Browse files

Use get_shape instead of result

parent 4566709b
...@@ -10,7 +10,7 @@ void auto_contiguous::apply(program& p) const ...@@ -10,7 +10,7 @@ void auto_contiguous::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
shape s = ins->result; shape s = ins->get_shape();
if(not s.standard()) if(not s.standard())
{ {
auto c = p.insert_instruction(std::next(ins), contiguous{}, ins); auto c = p.insert_instruction(std::next(ins), contiguous{}, ins);
......
...@@ -17,7 +17,7 @@ void dead_code_elimination::apply(program& p) const ...@@ -17,7 +17,7 @@ void dead_code_elimination::apply(program& p) const
continue; continue;
const auto i = std::prev(ins); const auto i = std::prev(ins);
// Skip instruction with empty shape as output unless its a builtin // Skip instruction with empty shape as output unless its a builtin
if(i->result.elements() == 0 and not(i->name().front() == '@')) if(i->get_shape().elements() == 0 and not(i->name().front() == '@'))
continue; continue;
// Skip the last instruction // Skip the last instruction
if(i == last) if(i == last)
......
...@@ -172,7 +172,7 @@ inline shape compute_shape(const operation& op, const std::vector<instruction_re ...@@ -172,7 +172,7 @@ inline shape compute_shape(const operation& op, const std::vector<instruction_re
{ {
std::vector<shape> shapes(args.size()); std::vector<shape> shapes(args.size());
std::transform( std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->result; }); args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return op.compute_shape(shapes); return op.compute_shape(shapes);
} }
......
...@@ -60,7 +60,7 @@ static void print_program(std::ostream& os, const program& p, F annonate) ...@@ -60,7 +60,7 @@ static void print_program(std::ostream& os, const program& p, F annonate)
os << ")"; os << ")";
} }
os << " -> " << ins->result; os << " -> " << ins->get_shape();
annonate(ins, names); annonate(ins, names);
...@@ -198,7 +198,7 @@ shape program::get_parameter_shape(std::string name) const ...@@ -198,7 +198,7 @@ shape program::get_parameter_shape(std::string name) const
} }
}); });
if(ins != this->end()) if(ins != this->end())
return ins->result; return ins->get_shape();
else else
return {}; return {};
} }
...@@ -211,7 +211,7 @@ std::unordered_map<std::string, shape> program::get_parameter_shapes() const ...@@ -211,7 +211,7 @@ std::unordered_map<std::string, shape> program::get_parameter_shapes() const
if(ins.name() == "@param") if(ins.name() == "@param")
{ {
auto&& name = any_cast<builtin::param>(ins.op).parameter; auto&& name = any_cast<builtin::param>(ins.op).parameter;
result[name] = ins.result; result[name] = ins.get_shape();
} }
} }
return result; return result;
...@@ -229,7 +229,7 @@ std::size_t program::size() const { return impl->instructions.size(); } ...@@ -229,7 +229,7 @@ std::size_t program::size() const { return impl->instructions.size(); }
instruction_ref program::begin() const { return impl->instructions.begin(); } instruction_ref program::begin() const { return impl->instructions.begin(); }
instruction_ref program::end() const { return impl->instructions.end(); } instruction_ref program::end() const { return impl->instructions.end(); }
shape program::get_shape() const { return impl->instructions.back().result; } shape program::get_shape() const { return impl->instructions.back().get_shape(); }
instruction_ref program::validate() const instruction_ref program::validate() const
{ {
...@@ -296,7 +296,7 @@ argument generic_eval(const program& p, ...@@ -296,7 +296,7 @@ argument generic_eval(const program& p,
} }
else if(ins->name() == "@outline") else if(ins->name() == "@outline")
{ {
results.emplace(ins, trace(ins, [&] { return argument{ins->result, nullptr}; })); results.emplace(ins, trace(ins, [&] { return argument{ins->get_shape(), nullptr}; }));
} }
else else
{ {
...@@ -309,7 +309,7 @@ argument generic_eval(const program& p, ...@@ -309,7 +309,7 @@ argument generic_eval(const program& p,
return results[i]; return results[i];
}); });
results.emplace(ins, results.emplace(ins,
trace(ins, [&] { return ins->op.compute(ctx, ins->result, values); })); trace(ins, [&] { return ins->op.compute(ctx, ins->get_shape(), values); }));
} }
assert(results.find(ins) != results.end()); assert(results.find(ins) != results.end());
} }
......
...@@ -44,7 +44,7 @@ void simplify_reshapes::apply(program& p) const ...@@ -44,7 +44,7 @@ void simplify_reshapes::apply(program& p) const
for(auto start : iterator_for(reshapes)) for(auto start : iterator_for(reshapes))
{ {
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) { auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->result == (*start)->result and i != (*start); return i->get_shape() == (*start)->get_shape() and i != (*start);
}); });
if(last != reshapes.rend()) if(last != reshapes.rend())
{ {
......
...@@ -372,10 +372,10 @@ struct miopen_apply ...@@ -372,10 +372,10 @@ struct miopen_apply
auto&& op = any_cast<convolution>(ins->op); auto&& op = any_cast<convolution>(ins->op);
auto conv = miopen_convolution{op, make_conv(op)}; auto conv = miopen_convolution{op, make_conv(op)};
auto ws = conv.compile(ctx, ins->result, ins->arguments); auto ws = conv.compile(ctx, ins->get_shape(), ins->arguments);
auto workspace = insert_allocation(ins, ws, "workspace"); auto workspace = insert_allocation(ins, ws, "workspace");
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction( return prog->replace_instruction(
ins, conv, ins->arguments.at(0), ins->arguments.at(1), workspace, output); ins, conv, ins->arguments.at(0), ins->arguments.at(1), workspace, output);
...@@ -385,7 +385,7 @@ struct miopen_apply ...@@ -385,7 +385,7 @@ struct miopen_apply
{ {
auto&& op = any_cast<pooling>(ins->op); auto&& op = any_cast<pooling>(ins->op);
auto pd = make_pooling(op); auto pd = make_pooling(op);
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction( return prog->replace_instruction(
ins, miopen_pooling{op, std::move(pd)}, ins->arguments.at(0), output); ins, miopen_pooling{op, std::move(pd)}, ins->arguments.at(0), output);
...@@ -397,7 +397,7 @@ struct miopen_apply ...@@ -397,7 +397,7 @@ struct miopen_apply
auto ad = make_relu(); auto ad = make_relu();
if(op.mode == "relu") if(op.mode == "relu")
{ {
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction( return prog->replace_instruction(
ins, miopen_relu{std::move(ad)}, ins->arguments.at(0), output); ins, miopen_relu{std::move(ad)}, ins->arguments.at(0), output);
} }
...@@ -406,7 +406,7 @@ struct miopen_apply ...@@ -406,7 +406,7 @@ struct miopen_apply
instruction_ref apply_add(instruction_ref ins) instruction_ref apply_add(instruction_ref ins)
{ {
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction( return prog->replace_instruction(
ins, hip_add{}, ins->arguments.at(0), ins->arguments.at(1), output); ins, hip_add{}, ins->arguments.at(0), ins->arguments.at(1), output);
} }
...@@ -414,7 +414,7 @@ struct miopen_apply ...@@ -414,7 +414,7 @@ struct miopen_apply
instruction_ref apply_gemm(instruction_ref ins) instruction_ref apply_gemm(instruction_ref ins)
{ {
auto&& op = any_cast<gemm>(ins->op); auto&& op = any_cast<gemm>(ins->op);
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction( return prog->replace_instruction(
ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output); ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
} }
...@@ -422,14 +422,14 @@ struct miopen_apply ...@@ -422,14 +422,14 @@ struct miopen_apply
instruction_ref apply_contiguous(instruction_ref ins) instruction_ref apply_contiguous(instruction_ref ins)
{ {
auto&& op = any_cast<contiguous>(ins->op); auto&& op = any_cast<contiguous>(ins->op);
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(ins, miopen_contiguous{op}, ins->arguments.at(0), output); return prog->replace_instruction(ins, miopen_contiguous{op}, ins->arguments.at(0), output);
} }
instruction_ref apply_batch_norm_inference(instruction_ref ins) instruction_ref apply_batch_norm_inference(instruction_ref ins)
{ {
auto&& op = any_cast<batch_norm_inference>(ins->op); auto&& op = any_cast<batch_norm_inference>(ins->op);
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->get_shape());
shape old_shape = ins->arguments.at(1)->get_shape(); shape old_shape = ins->arguments.at(1)->get_shape();
std::vector<int64_t> new_shape{1, static_cast<int64_t>(old_shape.elements()), 1, 1}; std::vector<int64_t> new_shape{1, static_cast<int64_t>(old_shape.elements()), 1, 1};
auto reshape_op = reshape{new_shape}; auto reshape_op = reshape{new_shape};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment