Commit 5856e5d0 authored by Shucai Xiao
Browse files

Merge branch 'adjust_gpu_allocation' into ins_fp32_fp16

parents adcd9f39 7c3b9d48
......@@ -21,7 +21,7 @@ struct unary
// Computes the output shape of the operator from its input shapes.
// The inputs are first run through check_shapes{inputs}.has(1)
// (presumably this enforces exactly one input -- confirm against check_shapes).
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
// NOTE(review): two return statements follow each other. This text looks like
// a rendered diff whose +/- markers were stripped: the first return is
// presumably the removed line and the second its replacement -- confirm
// against the commit. As literally written, only the first return executes
// (the input shape is handed back unchanged) and the second is unreachable.
return inputs.at(0);
// Builds the result from only the element type and dimensions (lens) of the
// first input; presumably this yields the default layout for those
// dimensions rather than copying the input's layout -- TODO confirm.
return {inputs.at(0).type(), inputs.at(0).lens()};
}
};
......
......@@ -593,15 +593,29 @@ struct cpu_unary
{
Op op;
// Reports the name of the wrapped operator `op`.
std::string name() const { return op.name(); }
// NOTE(review): compute_shape is defined twice below with an identical
// signature, which would not compile as-is. This text looks like a rendered
// diff with the +/- markers stripped: the one-line form is presumably the
// removed version and the multi-line form its replacement -- confirm against
// the commit.
//
// One-line form: returns the first input shape unchanged.
shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
// Multi-line form: builds the output shape from only the element type and
// dimensions (lens) of the first input; presumably this yields the default
// layout for those dimensions -- TODO confirm.
// Neither form checks that `inputs` is non-empty before calling front().
shape compute_shape(const std::vector<shape>& inputs) const
{
return {inputs.front().type(), inputs.front().lens()};
}
// Applies op.fcn() to every element of args[0] and returns the results in a
// newly constructed argument of shape `output_shape`.
// The context parameter is unnamed and unused here.
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// Nested visits give typed views of the output and the input
// (presumably dispatching on element type -- confirm against argument::visit).
result.visit([&](auto output) {
args[0].visit([&](auto input) {
// NOTE(review): this unconditional transform is followed by a branch that
// performs the same transform again for packed inputs. This looks like the
// removed line of a diff whose +/- markers were stripped -- confirm against
// the commit. As literally written, it runs before the branch below.
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
// Packed input: apply the function over the input's iteration range directly.
if(input.get_shape().packed())
{
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
}
else
{
// Non-packed input: walk every multi-dimensional index of the output shape
// and read/write through that index, so input and output are addressed by
// the same logical position.
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = op.fcn()(input(idx.begin(), idx.end()));
});
}
});
});
return result;
}
};
......@@ -784,7 +798,11 @@ struct cpu_binary
{
Op op;
// Reports the name of the wrapped operator `op`.
std::string name() const { return op.name(); }
// NOTE(review): compute_shape is defined twice below with an identical
// signature, which would not compile as-is. This text looks like a rendered
// diff with the +/- markers stripped: the one-line form is presumably the
// removed version and the multi-line form its replacement -- confirm against
// the commit.
//
// One-line form: returns the first input shape unchanged.
shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
// Multi-line form: builds the output shape from only the element type and
// dimensions (lens) of the first input. Only the first input is consulted;
// any further inputs do not influence the result here.
// Neither form checks that `inputs` is non-empty before calling front().
shape compute_shape(const std::vector<shape>& inputs) const
{
// operator will generate standard output shape
return {inputs.front().type(), inputs.front().lens()};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
......
......@@ -9,14 +9,10 @@ namespace gpu {
void adjust_allocation::apply(program& p) const
{
std::vector<std::string> ins_names = {"gpu::fp_conversion"};
for(auto ins : iterator_for(p))
{
// skip instructions not in the set
if(std::find(ins_names.begin(), ins_names.end(), ins->name()) == ins_names.end())
{
if(ins->name() == "load")
continue;
}
auto alias_ins = instruction::get_output_alias(ins, true);
if(alias_ins->name() == "hip::allocate")
......
......@@ -45,7 +45,7 @@ struct unary_device : oper<Derived>
// Computes the output shape of the device operator from its input shapes.
// The inputs are first run through check_shapes{inputs, *this}.has(2)
// (presumably this enforces exactly two entries -- confirm against
// check_shapes; what the second entry represents is not visible here).
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2);
// NOTE(review): two return statements follow each other. This text looks like
// a rendered diff whose +/- markers were stripped: the first return is
// presumably the removed line and the second its replacement -- confirm
// against the commit. As literally written, only the first return executes
// and the second is unreachable.
return inputs.at(0);
// Builds the result from only the element type and dimensions (lens) of the
// first input; presumably this yields the default layout -- TODO confirm.
return {inputs.at(0).type(), inputs.at(0).lens()};
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......@@ -63,7 +63,7 @@ struct binary_device : oper<Derived>
// Computes the output shape of the device operator from its input shapes.
// The inputs are first run through check_shapes{inputs, *this}.has(3)
// (presumably this enforces exactly three entries -- confirm against
// check_shapes; the role of each entry is not visible here).
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(3);
// NOTE(review): two return statements follow each other. This text looks like
// a rendered diff whose +/- markers were stripped: the first return is
// presumably the removed line and the second its replacement -- confirm
// against the commit. As literally written, only the first return executes
// and the second is unreachable.
return inputs.at(0);
// Builds the result from only the element type and dimensions (lens) of the
// first input; presumably this yields the default layout -- TODO confirm.
return {inputs.at(0).type(), inputs.at(0).lens()};
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment