Commit 1f8aa24f authored by Paul's avatar Paul
Browse files

Format

parent 4d6a1a8b
......@@ -61,15 +61,16 @@ MIGRAPHX_REGISTER_OP(miopen_op);
void compile_miopen::apply(module& m) const
{
for(auto ins:iterator_for(m))
for(auto ins : iterator_for(m))
{
if (ins->name() != "gpu::miopen_op")
if(ins->name() != "gpu::miopen_op")
continue;
auto op = any_cast<miopen_op>(ins->get_operator()).op;
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs()));
auto op = any_cast<miopen_op>(ins->get_operator()).op;
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs()));
std::size_t ws = v.get("workspace", 0);
auto inputs = ins->inputs();
auto alloc = m.insert_instruction(ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}}));
auto inputs = ins->inputs();
auto alloc = m.insert_instruction(
ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}}));
inputs.insert(std::prev(inputs.end()), alloc);
m.replace_instruction(ins, op, inputs);
......
......@@ -216,7 +216,9 @@ shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vec
#endif
}
value miopen_convolution::compile(context& ctx, const shape& output, const std::vector<shape>& input)
value miopen_convolution::compile(context& ctx,
const shape& output,
const std::vector<shape>& input)
{
if(cd == nullptr)
cd = make_conv(op);
......
......@@ -240,9 +240,12 @@ struct miopen_apply
// TODO: Use make_op
operation conv = miopen_convolution{op};
auto output = insert_allocation(ins, ins->get_shape());
return mod->replace_instruction(
ins, make_op("gpu::miopen_op", {{"op", to_value(conv)}}), ins->inputs().at(0), ins->inputs().at(1), output);
return mod->replace_instruction(ins,
make_op("gpu::miopen_op", {{"op", to_value(conv)}}),
ins->inputs().at(0),
ins->inputs().at(1),
output);
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment