Commit 0083b90b authored by Paul's avatar Paul
Browse files

Add reflect methods

parent 2c6acfde
......@@ -9,6 +9,7 @@
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/reflect.hpp>
#include <migraph/streamutils.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
......@@ -58,7 +59,8 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
char delim = '[';
reflect_each(x, [&](auto& y, auto name, auto&&...) {
os << delim;
os << name << "=" << y;
os << name << "=";
stream_write_value(os, y);
delim = ',';
});
if(delim == ',')
......
......@@ -35,7 +35,15 @@ struct batch_norm_inference
bn_infer_mode_t bn_mode = spatial;
bool is_test = false;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.epsilon, "epsilon"),
f(self.momentum, "momentum"),
f(self.bn_mode, "bn_mode")
);
}
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -56,6 +64,18 @@ struct convolution
valid
};
padding_mode_t padding_mode = default_;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.padding, "padding"),
f(self.stride, "stride"),
f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode")
);
}
std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -110,16 +130,6 @@ struct convolution
MIGRAPH_THROW("Invalid padding mode");
}
}
friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{
os << op.name() << "[";
os << "padding={" << stream_range(op.padding) << "}, ";
os << "stride={" << stream_range(op.stride) << "}, ";
os << "dilation={" << stream_range(op.dilation) << "}";
os << "]";
return os;
}
};
struct im2col
......@@ -133,6 +143,18 @@ struct im2col
same,
valid
};
padding_mode_t padding_mode = default_;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.padding, "padding"),
f(self.stride, "stride"),
f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode")
);
}
std::string name() const { return "im2col"; }
......@@ -168,6 +190,18 @@ struct pooling
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> lengths = {{1, 1}};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.mode, "mode"),
f(self.padding, "padding"),
f(self.stride, "stride"),
f(self.lengths, "lengths")
);
}
std::string name() const { return "pooling"; }
shape compute_shape(std::vector<shape> inputs) const
......@@ -196,16 +230,6 @@ struct pooling
1)),
}};
}
friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{
os << op.name() << "[";
os << "padding={" << stream_range(op.padding) << "}, ";
os << "stride={" << stream_range(op.stride) << "}, ";
os << "lengths={" << stream_range(op.lengths) << "}";
os << "]";
return os;
}
};
struct activation
......@@ -227,6 +251,15 @@ struct activation
struct transpose
{
std::vector<int64_t> dims;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.dims, "dims")
);
}
std::string name() const { return "transpose"; }
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -258,13 +291,6 @@ struct transpose
{
return {std::move(output_shape), std::move(args.front().data)};
}
friend std::ostream& operator<<(std::ostream& os, const transpose& op)
{
os << op.name() << "[";
os << "dims={" << stream_range(op.dims) << "}";
os << "]";
return os;
}
};
struct contiguous
......@@ -288,6 +314,17 @@ struct slice
std::vector<int64_t> axes;
std::vector<int64_t> starts;
std::vector<int64_t> ends;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.axes, "axes"),
f(self.starts, "starts"),
f(self.ends, "ends")
);
}
std::string name() const { return "slice"; }
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
......@@ -360,6 +397,15 @@ struct slice
struct squeeze
{
std::vector<int64_t> axes;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.axes, "axes")
);
}
std::string name() const { return "squeeze"; }
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -400,6 +446,15 @@ struct squeeze
struct unsqueeze
{
std::vector<int64_t> axes;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.axes, "axes")
);
}
std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -431,6 +486,15 @@ struct unsqueeze
struct reshape
{
std::vector<int64_t> dims;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.dims, "dims")
);
}
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -470,19 +534,22 @@ struct reshape
{
return {std::move(output_shape), std::move(args.front().data)};
}
friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{
os << op.name() << "[";
os << "dims={" << stream_range(op.dims) << "}";
os << "]";
return os;
}
};
struct gemm
{
float alpha = 1.0;
float beta = 0.0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.alpha, "alpha"),
f(self.beta, "beta")
);
}
std::string name() const { return "gemm"; }
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -496,13 +563,6 @@ struct gemm
to_string_range(b.lens()) + "}");
return {t, {a.lens()[0], b.lens()[1]}};
}
friend std::ostream& operator<<(std::ostream& os, const gemm& op)
{
os << op.name() << "[";
os << "]";
return os;
}
};
struct unary
......@@ -587,6 +647,15 @@ struct softmax
struct flatten
{
uint64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.axis, "axis")
);
}
std::string name() const { return "flatten"; }
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -607,17 +676,19 @@ struct flatten
{
return {std::move(output_shape), std::move(args.front().data)};
}
friend std::ostream& operator<<(std::ostream& os, const flatten& op)
{
os << op.name() << "[";
os << "axis=" << op.axis;
os << "]";
return os;
}
};
struct broadcast
{
uint64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.axis, "axis")
);
}
shape broadcast_shape;
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
......@@ -649,18 +720,10 @@ struct broadcast
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
friend std::ostream& operator<<(std::ostream& os, const broadcast& op)
{
os << op.name() << "[";
os << "axis=" << op.axis;
os << "]";
return os;
}
};
struct binary
{
uint64_t broadcast = 0;
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
......@@ -692,6 +755,16 @@ struct load
{
shape s;
std::size_t offset = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.s, "shape"),
f(self.offset, "offset")
);
}
std::string name() const { return "load"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
......@@ -707,6 +780,15 @@ struct load
struct outline
{
shape s;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.s, "shape")
);
}
std::string name() const { return "outline"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
......
......@@ -41,7 +41,7 @@ void reflect_each(T& x, F f)
{
return reflect(x, [](auto&& y, auto... ys) { return pack(std::ref(y), ys...); })(
[&](auto&&... xs) {
each_args([&](auto p) { p([&](auto&& y, auto... ys) { f(y, ys...); }); }, xs...);
each_args([&](auto p) { p([&](auto&& y, auto... ys) { f(y.get(), ys...); }); }, xs...);
});
}
......
......@@ -3,6 +3,7 @@
#include <ostream>
#include <algorithm>
#include <migraph/rank.hpp>
namespace migraph {
......@@ -31,6 +32,27 @@ inline stream_range_container<Range> stream_range(const Range& r)
return {r};
}
namespace detail {
template<class Range>
auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r) -> decltype(r.begin(), r.end(), void())
{
os << stream_range(r);
}
template<class T>
void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
{
os << x;
}
} // namespace detail
template<class T>
void stream_write_value(std::ostream& os, const T& x)
{
detail::stream_write_value_impl(rank<1>{}, os, x);
}
} // namespace migraph
#endif
......@@ -255,7 +255,7 @@ struct onnx_parser
? op::batch_norm_inference::spatial
: op::batch_norm_inference::per_activation;
}
op::batch_norm_inference op{epsilon, momentum, bn_mode, is_test};
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return prog.add_instruction(op, std::move(args));
}
......
......@@ -26,19 +26,18 @@ struct miopen_convolution
shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
// TODO: Add algo
return op::convolution::reflect(self.op, f);
}
std::string name() const { return "gpu::convolution"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
shape compile(context& ctx, const shape& output_shape, std::vector<instruction_ref> inputs);
friend std::ostream& operator<<(std::ostream& os, const miopen_convolution& self)
{
os << self.name() << "[";
os << self.op << ", ";
os << "algo=" << self.algo;
os << "]";
return os;
}
};
} // namespace gpu
......
......@@ -54,7 +54,7 @@ void pytorch_conv_bn_relu_maxpool()
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::batch_norm_inference{}, l5, p3, p4, p5, p6);
auto l6 = p.add_instruction(migraph::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraph::op::activation{"relu"}, l6);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
......
......@@ -9,6 +9,7 @@
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/reflect.hpp>
#include <migraph/streamutils.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
......@@ -58,7 +59,8 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
char delim = '[';
reflect_each(x, [&](auto& y, auto name, auto&&...) {
os << delim;
os << name << "=" << y;
os << name << "=";
stream_write_value(os, y);
delim = ',';
});
if(delim == ',')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment