Commit 682b524e authored by Paul's avatar Paul
Browse files

Dont lower transpose and reshape

parent 44513aca
...@@ -221,9 +221,9 @@ struct transpose ...@@ -221,9 +221,9 @@ struct transpose
} }
return {t, output_lens, output_strides}; return {t, output_lens, output_strides};
} }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
MIGRAPH_THROW("not computable"); return {output_shape, std::move(args.front().data)};
} }
}; };
...@@ -286,9 +286,9 @@ struct reshape ...@@ -286,9 +286,9 @@ struct reshape
return s; return s;
} }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
MIGRAPH_THROW("not computable"); return {output_shape, std::move(args.front().data)};
} }
friend std::ostream& operator<<(std::ostream& os, const reshape& op) friend std::ostream& operator<<(std::ostream& os, const reshape& op)
......
...@@ -185,18 +185,6 @@ struct cpu_pooling ...@@ -185,18 +185,6 @@ struct cpu_pooling
} }
}; };
struct cpu_transpose
{
transpose op;
std::string name() const { return "cpu::transpose"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct cpu_contiguous struct cpu_contiguous
{ {
contiguous op; contiguous op;
...@@ -214,18 +202,6 @@ struct cpu_contiguous ...@@ -214,18 +202,6 @@ struct cpu_contiguous
} }
}; };
struct cpu_reshape
{
reshape op;
std::string name() const { return "cpu::reshape"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct cpu_gemm struct cpu_gemm
{ {
gemm op; gemm op;
...@@ -527,9 +503,7 @@ struct cpu_apply ...@@ -527,9 +503,7 @@ struct cpu_apply
apply_map["gemm"] = extend_op<cpu_gemm, gemm>(); apply_map["gemm"] = extend_op<cpu_gemm, gemm>();
apply_map["batch_norm_inference"] = apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, batch_norm_inference>(); extend_op<cpu_batch_norm_inference, batch_norm_inference>();
apply_map["reshape"] = extend_op<cpu_reshape, reshape>();
apply_map["contiguous"] = extend_op<cpu_contiguous, contiguous>(); apply_map["contiguous"] = extend_op<cpu_contiguous, contiguous>();
apply_map["transpose"] = extend_op<cpu_transpose, transpose>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>(); apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>(); apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>();
......
...@@ -183,22 +183,6 @@ struct miopen_gemm ...@@ -183,22 +183,6 @@ struct miopen_gemm
} }
}; };
struct miopen_transpose
{
transpose op;
std::string name() const { return "gpu::transpose"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(0)});
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct miopen_contiguous struct miopen_contiguous
{ {
contiguous op; contiguous op;
...@@ -271,10 +255,6 @@ struct miopen_apply ...@@ -271,10 +255,6 @@ struct miopen_apply
{ {
apply_gemm(it); apply_gemm(it);
} }
else if(it->op.name() == "transpose")
{
apply_transpose(it);
}
else if(it->op.name() == "contiguous") else if(it->op.name() == "contiguous")
{ {
apply_contiguous(it); apply_contiguous(it);
...@@ -346,13 +326,6 @@ struct miopen_apply ...@@ -346,13 +326,6 @@ struct miopen_apply
ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output); ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
} }
void apply_transpose(instruction_ref ins)
{
auto&& op = any_cast<transpose>(ins->op);
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(ins, miopen_transpose{op}, ins->arguments.at(0), output);
}
void apply_contiguous(instruction_ref ins) void apply_contiguous(instruction_ref ins)
{ {
auto&& op = any_cast<contiguous>(ins->op); auto&& op = any_cast<contiguous>(ins->op);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment