Commit 5856e5d0 authored by Shucai Xiao
Browse files

Merge branch 'adjust_gpu_allocation' into ins_fp32_fp16

parents adcd9f39 7c3b9d48
...@@ -21,7 +21,7 @@ struct unary ...@@ -21,7 +21,7 @@ struct unary
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.at(0); return {inputs.at(0).type(), inputs.at(0).lens()};
} }
}; };
......
...@@ -593,15 +593,29 @@ struct cpu_unary ...@@ -593,15 +593,29 @@ struct cpu_unary
{ {
Op op; Op op;
std::string name() const { return op.name(); } std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); } shape compute_shape(const std::vector<shape>& inputs) const
{
return {inputs.front().type(), inputs.front().lens()};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), op.fcn()); if(input.get_shape().packed())
{
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = op.fcn()(input(idx.begin(), idx.end()));
});
}
}); });
}); });
return result; return result;
} }
}; };
...@@ -784,7 +798,11 @@ struct cpu_binary ...@@ -784,7 +798,11 @@ struct cpu_binary
{ {
Op op; Op op;
std::string name() const { return op.name(); } std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); } shape compute_shape(const std::vector<shape>& inputs) const
{
// operator will generate standard output shape
return {inputs.front().type(), inputs.front().lens()};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
......
...@@ -9,14 +9,10 @@ namespace gpu { ...@@ -9,14 +9,10 @@ namespace gpu {
void adjust_allocation::apply(program& p) const void adjust_allocation::apply(program& p) const
{ {
std::vector<std::string> ins_names = {"gpu::fp_conversion"};
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
// skip instructions not in the set if(ins->name() == "load")
if(std::find(ins_names.begin(), ins_names.end(), ins->name()) == ins_names.end())
{
continue; continue;
}
auto alias_ins = instruction::get_output_alias(ins, true); auto alias_ins = instruction::get_output_alias(ins, true);
if(alias_ins->name() == "hip::allocate") if(alias_ins->name() == "hip::allocate")
......
...@@ -45,7 +45,7 @@ struct unary_device : oper<Derived> ...@@ -45,7 +45,7 @@ struct unary_device : oper<Derived>
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
return inputs.at(0); return {inputs.at(0).type(), inputs.at(0).lens()};
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
...@@ -63,7 +63,7 @@ struct binary_device : oper<Derived> ...@@ -63,7 +63,7 @@ struct binary_device : oper<Derived>
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
return inputs.at(0); return {inputs.at(0).type(), inputs.at(0).lens()};
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment