Unverified Commit 767ca0cc authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #257 from ROCmSoftwarePlatform/equality

Ensure reflect methods for all operators 
parents a713a6d3 82762e8a
......@@ -25,6 +25,13 @@ namespace gpu {
struct hip_logsoftmax
{
op::logsoftmax op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::logsoftmax"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
......
......@@ -13,6 +13,13 @@ struct context;
struct miopen_lrn
{
shared<lrn_descriptor> ldesc;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return gpu::reflect(self.ldesc.get(), f);
}
std::string name() const { return "gpu::lrn"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
......
......@@ -162,6 +162,38 @@ inline fused_operator_args make_fused_args()
return make_obj<fused_operator_args>(&miopenCreateOperatorArgs);
}
template <class F>
auto reflect(miopenActivationDescriptor_t ad, F f)
{
assert(ad != nullptr);
miopenActivationMode_t mode = miopenActivationPASTHRU;
double alpha = 0.0;
double beta = 0.0;
double gamma = 0.0;
miopenGetActivationDescriptor(ad, &mode, &alpha, &beta, &gamma);
return pack(f(std::move(mode), "mode"), // NOLINT
f(std::move(alpha), "alpha"), // NOLINT
f(std::move(beta), "beta"), // NOLINT
f(std::move(gamma), "gamma")); // NOLINT
}
template <class F>
auto reflect(miopenLRNDescriptor_t lrnd, F f)
{
assert(lrnd != nullptr);
miopenLRNMode_t mode = miopenLRNWithinChannel;
unsigned int n = 0;
double alpha = 0.0;
double beta = 0.0;
double k = 0.0;
miopenGetLRNDescriptor(lrnd, &mode, &n, &alpha, &beta, &k);
return pack(f(std::move(mode), "mode"), // NOLINT
f(std::move(n), "n"), // NOLINT
f(std::move(alpha), "alpha"), // NOLINT
f(std::move(beta), "beta"), // NOLINT
f(std::move(k), "k")); // NOLINT
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -14,6 +14,12 @@ struct hip_pad
{
op::pad op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::pad"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
......
......@@ -16,6 +16,12 @@ struct miopen_pooling
op::pooling op;
shared<pooling_descriptor> pd;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::pooling"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
......
......@@ -13,6 +13,13 @@ struct context;
struct miopen_relu
{
shared<activation_descriptor> ad;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return gpu::reflect(self.ad.get(), f);
}
std::string name() const { return "gpu::relu"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
......
......@@ -13,6 +13,13 @@ struct context;
struct miopen_sigmoid
{
shared<activation_descriptor> ad;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return gpu::reflect(self.ad.get(), f);
}
std::string name() const { return "gpu::sigmoid"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
......
......@@ -13,6 +13,13 @@ struct context;
struct miopen_softmax
{
op::softmax op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::softmax"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
......
......@@ -13,6 +13,13 @@ struct context;
struct miopen_tanh
{
shared<activation_descriptor> ad;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return gpu::reflect(self.ad.get(), f);
}
std::string name() const { return "gpu::tanh"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
......
......@@ -14,6 +14,13 @@ struct hip_load_literal
{
shape s;
std::size_t n = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.s, "shape"), f(self.n, "id"));
}
std::string name() const { return "hip::load_literal"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
......
......@@ -20,6 +20,13 @@ struct eliminate_allocation_target
struct allocate
{
migraphx::shape s{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.s, "shape"));
}
std::string name() const { return "allocate"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{
......
......@@ -10,6 +10,13 @@ struct concat
{
concat(std::size_t axis) { op.axis = axis; }
migraphx::op::concat op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "eliminate_concat::concat"; }
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
......@@ -51,6 +58,13 @@ struct eliminate_concat_target
struct allocate
{
migraphx::shape s{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.s, "shape"));
}
std::string name() const { return "allocate"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{
......
......@@ -58,7 +58,7 @@ TEST_CASE(tanh_shape)
if(ins->name() == "hip::allocate")
{
migraphx::shape new_s{migraphx::shape::float_type, {3, 2}, {1, 3}};
migraphx::instruction::replace(ins, ins->get_operator(), new_s, ins->inputs());
ins->replace(migraphx::gpu::hip_allocate{new_s});
}
}
EXPECT(p1 != p2);
......
......@@ -18,6 +18,13 @@ struct memory_coloring_target
struct allocate
{
migraphx::shape s{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.s, "shape"));
}
std::string name() const { return "allocate"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{
......
......@@ -29,12 +29,12 @@ TEST_CASE(basic_graph_test)
EXPECT(migraphx::contains(test, "\"@0\"[label=\"@literal\"]"));
EXPECT(migraphx::contains(test, "\"y\"[label=\"@param:y\"]"));
EXPECT(migraphx::contains(test, "\"x\"[label=\"@param:x\"]"));
EXPECT(migraphx::contains(test, "\"@3\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"@4\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"x\" -> \"@3\""));
EXPECT(migraphx::contains(test, "\"y\" -> \"@3\""));
EXPECT(migraphx::contains(test, "\"@3\" -> \"@4\""));
EXPECT(migraphx::contains(test, "\"@0\" -> \"@4\""));
EXPECT(migraphx::contains(test, "\"@1\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"@2\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"x\" -> \"@1\""));
EXPECT(migraphx::contains(test, "\"y\" -> \"@1\""));
EXPECT(migraphx::contains(test, "\"@1\" -> \"@2\""));
EXPECT(migraphx::contains(test, "\"@0\" -> \"@2\""));
EXPECT(migraphx::contains(test, "[label=\"int64_type, {1}, {0}\"]"));
}
......
......@@ -69,7 +69,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{
os << x.name();
char delim = '[';
reflect_each(x, [&](auto& y, auto name) {
reflect_each(x, [&](auto&& y, auto name) {
os << delim;
os << name << "=";
stream_write_value(os, y);
......@@ -87,6 +87,8 @@ namespace operation_equal {
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
static_assert(is_reflectable<T>{} or sizeof(T) <= 1,
"Missing equality operator or reflect method.");
if(x.name() != y.name())
return false;
const auto& yy = any_cast<T>(y);
......@@ -175,7 +177,7 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
}
template <class T>
int output_alias_op(rank<0>, const T&, const std::vector<shape>&)
std::ptrdiff_t output_alias_op(rank<0>, const T&, const std::vector<shape>&)
{
return -1;
}
......@@ -188,7 +190,7 @@ auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
}
template <class T>
int output_alias_op(const T& x, const std::vector<shape>& shapes)
std::ptrdiff_t output_alias_op(const T& x, const std::vector<shape>& shapes)
{
return output_alias_op(rank<1>{}, x, shapes);
}
......@@ -238,7 +240,7 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'),
virtual('has_finalize', returns = 'bool', const = True, default = 'has_finalize_op'),
virtual('output_alias',
returns = 'int',
returns = 'std::ptrdiff_t',
input = 'const std::vector<shape>&',
const = True,
default = 'output_alias_op'),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment