Commit 1b0de8d5 authored by Shucai Xiao
Browse files

separate this pull request into ajust_gpu_allocation and eliminate_contiguous processing.

parent c46b0432
...@@ -21,7 +21,7 @@ struct unary ...@@ -21,7 +21,7 @@ struct unary
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return {inputs.at(0).type(), inputs.at(0).lens()}; return inputs.at(0);
} }
}; };
......
...@@ -593,9 +593,9 @@ struct cpu_unary ...@@ -593,9 +593,9 @@ struct cpu_unary
{ {
Op op; Op op;
std::string name() const { return op.name(); } std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
return {inputs.front().type(), inputs.front().lens()}; return {inputs.front().type(), inputs.front().lens()};
} }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
...@@ -603,16 +603,7 @@ struct cpu_unary ...@@ -603,16 +603,7 @@ struct cpu_unary
argument result{output_shape}; argument result{output_shape};
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
if(input.get_shape().packed()) std::transform(input.begin(), input.end(), output.begin(), op.fcn());
{
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = op.fcn()(input(idx.begin(), idx.end()));
});
}
}); });
}); });
...@@ -786,11 +777,11 @@ struct cpu_binary ...@@ -786,11 +777,11 @@ struct cpu_binary
{ {
Op op; Op op;
std::string name() const { return op.name(); } std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
// operator will generate standard output shape return { inputs.front().type(), inputs.front().lens()};
return {inputs.front().type(), inputs.front().lens()};
} }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
......
...@@ -45,7 +45,7 @@ struct unary_device : oper<Derived> ...@@ -45,7 +45,7 @@ struct unary_device : oper<Derived>
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
return {inputs.at(0).type(), inputs.at(0).lens()}; return inputs.at(1);
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
...@@ -63,7 +63,7 @@ struct binary_device : oper<Derived> ...@@ -63,7 +63,7 @@ struct binary_device : oper<Derived>
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
return {inputs.at(0).type(), inputs.at(0).lens()}; return inputs.at(2);
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......
...@@ -7,7 +7,7 @@ namespace gpu { ...@@ -7,7 +7,7 @@ namespace gpu {
shape miopen_lrn::compute_shape(const std::vector<shape>& inputs) const shape miopen_lrn::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2).standard(); check_shapes{inputs, *this}.has(2).not_broadcasted();
return inputs.at(1); return inputs.at(1);
} }
......
...@@ -7,7 +7,7 @@ namespace gpu { ...@@ -7,7 +7,7 @@ namespace gpu {
shape miopen_relu::compute_shape(const std::vector<shape>& inputs) const shape miopen_relu::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2).not_broadcasted().not_transposed(); check_shapes{inputs, *this}.has(2).not_broadcasted();
return inputs.at(1); return inputs.at(1);
} }
......
...@@ -7,7 +7,7 @@ namespace gpu { ...@@ -7,7 +7,7 @@ namespace gpu {
shape miopen_tanh::compute_shape(const std::vector<shape>& inputs) const shape miopen_tanh::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2).not_broadcasted().not_transposed(); check_shapes{inputs, *this}.has(2).not_broadcasted();
return inputs.at(1); return inputs.at(1);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment