Commit bb5f1eb6 authored by Paul's avatar Paul
Browse files

s/miopen_context/context

parent f0861316
#ifndef MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
#define MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
namespace migraph {
namespace detail {
template<class U>
void any_cast() {}
template<class T>
struct auto_any_caster
{
T& x;
template <class U>
operator U&()
{
return any_cast<U>(x);
}
operator T&()
{
return x;
}
};
}
template<class T>
detail::auto_any_caster<T> auto_any_cast(T& x)
{
return {x};
}
} // namespace migraph
#endif
......@@ -7,7 +7,7 @@
namespace migraph {
namespace miopen {
struct miopen_context
struct context
{
shared<miopen_handle> handle;
shared<rocblas_handle_ptr> rbhandle;
......
......@@ -9,8 +9,8 @@ namespace miopen {
struct target
{
std::string name() const;
std::vector<pass> get_passes(context& ctx) const;
context get_context() const;
std::vector<pass> get_passes(migraph::context& ctx) const;
migraph::context get_context() const;
};
} // namespace miopen
......
......@@ -25,7 +25,7 @@ struct miopen_convolution
check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{
auto x_desc = make_tensor(args[0].get_shape());
auto w_desc = make_tensor(args[1].get_shape());
......@@ -76,7 +76,7 @@ struct miopen_pooling
check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(1)});
}
argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
......@@ -108,7 +108,7 @@ struct miopen_add
return inputs.at(0);
}
argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{
if(args[1].get_shape().broadcasted())
{
......@@ -154,7 +154,7 @@ struct miopen_gemm
check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{
float alpha = 1.0f;
float beta = 0.0f;
......@@ -192,7 +192,7 @@ struct miopen_relu
return inputs.at(1);
}
argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{
float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[0].get_shape());
......@@ -216,7 +216,7 @@ struct miopen_apply
void apply()
{
prog->insert_instruction(prog->begin(), check_context<miopen_context>{});
prog->insert_instruction(prog->begin(), check_context<context>{});
for(auto it = prog->begin(); it != prog->end(); it++)
{
if(it->op.name() == "convolution")
......
......@@ -6,13 +6,13 @@
namespace migraph {
namespace miopen {
std::vector<pass> target::get_passes(context&) const { return {lowering{}, write_literals{}}; }
std::vector<pass> target::get_passes(migraph::context&) const { return {lowering{}, write_literals{}}; }
std::string target::name() const { return "miopen"; }
context target::get_context() const
migraph::context target::get_context() const
{
return miopen_context{share(make_obj<miopen_handle>(&miopenCreate)),
return context{share(make_obj<miopen_handle>(&miopenCreate)),
share(create_rocblas_handle_ptr())};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment