Commit 1b0de8d5 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Separate this pull request into adjust_gpu_allocation and eliminate_contiguous processing.

parent c46b0432
......@@ -21,7 +21,7 @@ struct unary
shape compute_shape(std::vector<shape> inputs) const
{
    // A unary op is elementwise: forward the single input shape
    // unchanged (preserving its strides), rather than rebuilding a
    // packed shape from {type, lens}.
    check_shapes{inputs}.has(1);
    return inputs.at(0);
}
};
......
......@@ -593,9 +593,9 @@ struct cpu_unary
{
Op op;
std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const
{
    // Output takes the first input's type and lengths; constructing a
    // shape from {type, lens} yields a standard (packed) output shape.
    return {inputs.front().type(), inputs.front().lens()};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
......@@ -603,16 +603,7 @@ struct cpu_unary
argument result{output_shape};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
if(input.get_shape().packed())
{
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = op.fcn()(input(idx.begin(), idx.end()));
});
}
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
});
});
......@@ -786,11 +777,11 @@ struct cpu_binary
{
Op op;
std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const
{
    // Operator generates a standard output shape from the first
    // input's type and lengths (same contract as cpu_unary).
    return {inputs.front().type(), inputs.front().lens()};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
......
......@@ -45,7 +45,7 @@ struct unary_device : oper<Derived>
shape compute_shape(const std::vector<shape>& inputs) const
{
    // inputs = {input, output buffer}: the last argument is the
    // preallocated destination, so return its shape (strides included)
    // instead of synthesizing a packed shape from the input.
    check_shapes{inputs, *this}.has(2);
    return inputs.at(1);
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......@@ -63,7 +63,7 @@ struct binary_device : oper<Derived>
shape compute_shape(const std::vector<shape>& inputs) const
{
    // inputs = {lhs, rhs, output buffer}: return the preallocated
    // output buffer's shape, mirroring unary_device::compute_shape.
    check_shapes{inputs, *this}.has(3);
    return inputs.at(2);
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......
......@@ -7,7 +7,7 @@ namespace gpu {
shape miopen_lrn::compute_shape(const std::vector<shape>& inputs) const
{
    // Relaxed from .standard() to .not_broadcasted(): non-packed (but
    // non-broadcast) layouts are accepted here.
    // NOTE(review): assumes MIOpen LRN tolerates such layouts — confirm
    // against the MIOpen tensor-descriptor requirements.
    check_shapes{inputs, *this}.has(2).not_broadcasted();
    // Second argument is the preallocated output buffer.
    return inputs.at(1);
}
......
......@@ -7,7 +7,7 @@ namespace gpu {
shape miopen_relu::compute_shape(const std::vector<shape>& inputs) const
{
    // Relaxed: the .not_transposed() restriction was dropped, so only
    // broadcasted inputs are rejected.
    // NOTE(review): presumably MIOpen activation handles transposed
    // layouts via strides — verify before relying on this.
    check_shapes{inputs, *this}.has(2).not_broadcasted();
    // Second argument is the preallocated output buffer.
    return inputs.at(1);
}
......
......@@ -7,7 +7,7 @@ namespace gpu {
shape miopen_tanh::compute_shape(const std::vector<shape>& inputs) const
{
    // Relaxed: the .not_transposed() restriction was dropped, matching
    // miopen_relu; only broadcasted inputs are rejected.
    // NOTE(review): presumably MIOpen activation handles transposed
    // layouts via strides — verify before relying on this.
    check_shapes{inputs, *this}.has(2).not_broadcasted();
    // Second argument is the preallocated output buffer.
    return inputs.at(1);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment