Commit 3dc42104 authored by wsttiger's avatar wsttiger
Browse files

Added contiguous operator on CPU (called from GPU)

parent 8addb9d5
...@@ -173,6 +173,23 @@ struct miopen_gemm ...@@ -173,6 +173,23 @@ struct miopen_gemm
} }
}; };
struct miopen_contiguous
{
contiguous op;
std::string name() const { return "miopen::contiguous"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, from_gpu(args[0]))([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
});
});
return to_gpu(result);
}
};
struct miopen_relu struct miopen_relu
{ {
shared<activation_descriptor> ad; shared<activation_descriptor> ad;
...@@ -231,6 +248,10 @@ struct miopen_apply ...@@ -231,6 +248,10 @@ struct miopen_apply
{ {
apply_gemm(it); apply_gemm(it);
} }
else if(it->op.name() == "contiguous")
{
apply_contiguous(it);
}
} }
} }
...@@ -297,6 +318,13 @@ struct miopen_apply ...@@ -297,6 +318,13 @@ struct miopen_apply
prog->replace_instruction( prog->replace_instruction(
ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output); ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
} }
void apply_contiguous(instruction_ref ins)
{
auto&& op = any_cast<contiguous>(ins->op);
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(ins, miopen_contiguous{op}, ins->arguments.at(0), output);
}
}; };
struct miopen_pass struct miopen_pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment