"vscode:/vscode.git/clone" did not exist on "1a665a63b09a83ab06317f8acfe7e7f75037c5ab"
Commit 3ed217c9 authored by Paul's avatar Paul
Browse files

Ensure reflect methods for all operators

parent b2051bbc
...@@ -13,6 +13,13 @@ struct context; ...@@ -13,6 +13,13 @@ struct context;
struct miopen_relu struct miopen_relu
{ {
shared<activation_descriptor> ad; shared<activation_descriptor> ad;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return gpu::reflect(self.ad.get(), f);
}
std::string name() const { return "gpu::relu"; } std::string name() const { return "gpu::relu"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
......
...@@ -13,6 +13,13 @@ struct context; ...@@ -13,6 +13,13 @@ struct context;
struct miopen_sigmoid struct miopen_sigmoid
{ {
shared<activation_descriptor> ad; shared<activation_descriptor> ad;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return gpu::reflect(self.ad.get(), f);
}
std::string name() const { return "gpu::sigmoid"; } std::string name() const { return "gpu::sigmoid"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
......
...@@ -13,6 +13,13 @@ struct context; ...@@ -13,6 +13,13 @@ struct context;
struct miopen_softmax struct miopen_softmax
{ {
op::softmax op; op::softmax op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::softmax"; } std::string name() const { return "gpu::softmax"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
......
...@@ -13,6 +13,13 @@ struct context; ...@@ -13,6 +13,13 @@ struct context;
struct miopen_tanh struct miopen_tanh
{ {
shared<activation_descriptor> ad; shared<activation_descriptor> ad;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return gpu::reflect(self.ad.get(), f);
}
std::string name() const { return "gpu::tanh"; } std::string name() const { return "gpu::tanh"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
......
...@@ -14,6 +14,13 @@ struct hip_load_literal ...@@ -14,6 +14,13 @@ struct hip_load_literal
{ {
shape s; shape s;
std::size_t n = 0; std::size_t n = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.s, "shape"), f(self.n, "id"));
}
std::string name() const { return "hip::load_literal"; } std::string name() const { return "hip::load_literal"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
......
...@@ -20,6 +20,13 @@ struct eliminate_allocation_target ...@@ -20,6 +20,13 @@ struct eliminate_allocation_target
struct allocate struct allocate
{ {
migraphx::shape s{}; migraphx::shape s{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.s, "shape"));
}
std::string name() const { return "allocate"; } std::string name() const { return "allocate"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{ {
......
...@@ -10,6 +10,13 @@ struct concat ...@@ -10,6 +10,13 @@ struct concat
{ {
concat(std::size_t axis) { op.axis = axis; } concat(std::size_t axis) { op.axis = axis; }
migraphx::op::concat op; migraphx::op::concat op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "eliminate_concat::concat"; } std::string name() const { return "eliminate_concat::concat"; }
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{ {
...@@ -51,6 +58,13 @@ struct eliminate_concat_target ...@@ -51,6 +58,13 @@ struct eliminate_concat_target
struct allocate struct allocate
{ {
migraphx::shape s{}; migraphx::shape s{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.s, "shape"));
}
std::string name() const { return "allocate"; } std::string name() const { return "allocate"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{ {
......
...@@ -18,6 +18,13 @@ struct memory_coloring_target ...@@ -18,6 +18,13 @@ struct memory_coloring_target
struct allocate struct allocate
{ {
migraphx::shape s{}; migraphx::shape s{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.s, "shape"));
}
std::string name() const { return "allocate"; } std::string name() const { return "allocate"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{ {
......
...@@ -87,6 +87,7 @@ namespace operation_equal { ...@@ -87,6 +87,7 @@ namespace operation_equal {
template <class T, class U> template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{ {
static_assert(is_reflectable<T>{} or sizeof(T) <= 1, "Missing equality operator or reflect method.");
if(x.name() != y.name()) if(x.name() != y.name())
return false; return false;
const auto& yy = any_cast<T>(y); const auto& yy = any_cast<T>(y);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment