Commit bb5f1eb6 authored by Paul's avatar Paul
Browse files

s/miopen_context/context

parent f0861316
#ifndef MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
#define MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
namespace migraph {
namespace detail {
template<class U>
void any_cast() {}
template<class T>
struct auto_any_caster
{
T& x;
template <class U>
operator U&()
{
return any_cast<U>(x);
}
operator T&()
{
return x;
}
};
}
template<class T>
detail::auto_any_caster<T> auto_any_cast(T& x)
{
return {x};
}
} // namespace migraph
#endif
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace migraph { namespace migraph {
namespace miopen { namespace miopen {
struct miopen_context struct context
{ {
shared<miopen_handle> handle; shared<miopen_handle> handle;
shared<rocblas_handle_ptr> rbhandle; shared<rocblas_handle_ptr> rbhandle;
......
...@@ -9,8 +9,8 @@ namespace miopen { ...@@ -9,8 +9,8 @@ namespace miopen {
struct target struct target
{ {
std::string name() const; std::string name() const;
std::vector<pass> get_passes(context& ctx) const; std::vector<pass> get_passes(migraph::context& ctx) const;
context get_context() const; migraph::context get_context() const;
}; };
} // namespace miopen } // namespace miopen
......
...@@ -25,7 +25,7 @@ struct miopen_convolution ...@@ -25,7 +25,7 @@ struct miopen_convolution
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{ {
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto w_desc = make_tensor(args[1].get_shape()); auto w_desc = make_tensor(args[1].get_shape());
...@@ -76,7 +76,7 @@ struct miopen_pooling ...@@ -76,7 +76,7 @@ struct miopen_pooling
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(1)}); return op.compute_shape({inputs.at(1)});
} }
argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{ {
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
...@@ -108,7 +108,7 @@ struct miopen_add ...@@ -108,7 +108,7 @@ struct miopen_add
return inputs.at(0); return inputs.at(0);
} }
argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{ {
if(args[1].get_shape().broadcasted()) if(args[1].get_shape().broadcasted())
{ {
...@@ -154,7 +154,7 @@ struct miopen_gemm ...@@ -154,7 +154,7 @@ struct miopen_gemm
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{ {
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
...@@ -192,7 +192,7 @@ struct miopen_relu ...@@ -192,7 +192,7 @@ struct miopen_relu
return inputs.at(1); return inputs.at(1);
} }
argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{ {
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
...@@ -216,7 +216,7 @@ struct miopen_apply ...@@ -216,7 +216,7 @@ struct miopen_apply
void apply() void apply()
{ {
prog->insert_instruction(prog->begin(), check_context<miopen_context>{}); prog->insert_instruction(prog->begin(), check_context<context>{});
for(auto it = prog->begin(); it != prog->end(); it++) for(auto it = prog->begin(); it != prog->end(); it++)
{ {
if(it->op.name() == "convolution") if(it->op.name() == "convolution")
......
...@@ -6,13 +6,13 @@ ...@@ -6,13 +6,13 @@
namespace migraph { namespace migraph {
namespace miopen { namespace miopen {
std::vector<pass> target::get_passes(context&) const { return {lowering{}, write_literals{}}; } std::vector<pass> target::get_passes(migraph::context&) const { return {lowering{}, write_literals{}}; }
std::string target::name() const { return "miopen"; } std::string target::name() const { return "miopen"; }
context target::get_context() const migraph::context target::get_context() const
{ {
return miopen_context{share(make_obj<miopen_handle>(&miopenCreate)), return context{share(make_obj<miopen_handle>(&miopenCreate)),
share(create_rocblas_handle_ptr())}; share(create_rocblas_handle_ptr())};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment