Commit 04ca2e74 authored by wsttiger's avatar wsttiger
Browse files

Merge branch 'contiguous_on_cpu' into hip_contiguous

parents 5f68a283 e29a2396
......@@ -173,6 +173,43 @@ struct miopen_gemm
}
};
struct miopen_transpose
{
transpose op;
std::string name() const { return "miopen::transpose"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(0)});
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct miopen_contiguous
{
contiguous op;
std::string name() const { return "miopen::contiguous"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(0)});
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, from_gpu(args[0]))([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
});
});
return to_gpu(result);
}
};
struct miopen_relu
{
shared<activation_descriptor> ad;
......@@ -231,6 +268,14 @@ struct miopen_apply
{
apply_gemm(it);
}
else if(it->op.name() == "transpose")
{
apply_transpose(it);
}
else if(it->op.name() == "contiguous")
{
apply_contiguous(it);
}
}
}
......@@ -297,6 +342,20 @@ struct miopen_apply
prog->replace_instruction(
ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
}
void apply_transpose(instruction_ref ins)
{
auto&& op = any_cast<transpose>(ins->op);
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(ins, miopen_transpose{op}, ins->arguments.at(0), output);
}
void apply_contiguous(instruction_ref ins)
{
auto&& op = any_cast<contiguous>(ins->op);
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(ins, miopen_contiguous{op}, ins->arguments.at(0), output);
}
};
struct miopen_pass
......
......@@ -158,6 +158,47 @@ struct test_gemm
}
};
struct test_contiguous
{
migraph::program create_program() const
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraph::contiguous{}, x);
return p;
}
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] =
migraph::generate_argument({migraph::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}});
return m;
}
};
struct test_transpose
{
migraph::program create_program() const
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {4, 3, 4, 4}};
auto x = p.add_parameter("x", s);
std::vector<int64_t> perm = {0, 2, 3, 1};
auto l = p.add_instruction(migraph::transpose{perm}, x);
p.add_instruction(migraph::contiguous{}, l);
return p;
}
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 4, 4}});
return m;
}
};
int main()
{
verify_program<test_add>();
......@@ -165,4 +206,6 @@ int main()
verify_program<test_conv_relu>();
verify_program<test_conv_pooling>();
verify_program<test_gemm>();
verify_program<test_contiguous>();
verify_program<test_transpose>();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment