Commit d22bab64 authored by wsttiger's avatar wsttiger
Browse files

Added op namespace on operators

parents ad0ab357 3d264140
...@@ -15,6 +15,8 @@ migraph::argument from_gpu(migraph::argument arg); ...@@ -15,6 +15,8 @@ migraph::argument from_gpu(migraph::argument arg);
void gpu_sync(); void gpu_sync();
void copy_to_gpu(char* dst, const char* src, std::size_t size);
struct hip_allocate struct hip_allocate
{ {
std::string tag{}; std::string tag{};
...@@ -30,22 +32,6 @@ struct hip_allocate ...@@ -30,22 +32,6 @@ struct hip_allocate
} }
}; };
struct hip_load
{
shape s;
std::size_t offset = 0;
std::string name() const { return "hip::load"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs}.has(1);
return s;
}
argument compute(context&, const shape&, const std::vector<argument>& args) const
{
return {s, args[0].data() + offset};
}
};
struct hip_sync struct hip_sync
{ {
std::string tag{}; std::string tag{};
...@@ -81,8 +67,21 @@ struct hip_write ...@@ -81,8 +67,21 @@ struct hip_write
} }
}; };
struct hip_memcpy
{
std::string name() const { return "hip_memcpy"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(1); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
char* dst = args.at(0).data() + offset;
const char* src = args.at(1).data();
std::size_t size = args.at(1).get_shape().bytes();
copy_to_gpu(dst, src, size);
return {std::move(output_shape), dst};
}
std::size_t offset = 0;
};
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -44,7 +44,7 @@ inline tensor_descriptor make_tensor(const migraph::shape& s) ...@@ -44,7 +44,7 @@ inline tensor_descriptor make_tensor(const migraph::shape& s)
return t; return t;
} }
inline convolution_descriptor make_conv(const migraph::convolution& op) inline convolution_descriptor make_conv(const migraph::op::convolution& op)
{ {
auto c = make_obj<convolution_descriptor>(&miopenCreateConvolutionDescriptor); auto c = make_obj<convolution_descriptor>(&miopenCreateConvolutionDescriptor);
miopenInitConvolutionDescriptor(c.get(), miopenInitConvolutionDescriptor(c.get(),
...@@ -58,7 +58,7 @@ inline convolution_descriptor make_conv(const migraph::convolution& op) ...@@ -58,7 +58,7 @@ inline convolution_descriptor make_conv(const migraph::convolution& op)
return c; return c;
} }
inline pooling_descriptor make_pooling(const migraph::pooling& op) inline pooling_descriptor make_pooling(const migraph::op::pooling& op)
{ {
miopenPoolingMode_t mode; miopenPoolingMode_t mode;
if(op.mode == "max") if(op.mode == "max")
......
...@@ -22,7 +22,7 @@ namespace gpu { ...@@ -22,7 +22,7 @@ namespace gpu {
struct miopen_pooling struct miopen_pooling
{ {
pooling op; op::pooling op;
shared<pooling_descriptor> pd; shared<pooling_descriptor> pd;
std::string name() const { return "gpu::pooling"; } std::string name() const { return "gpu::pooling"; }
......
...@@ -22,7 +22,7 @@ namespace gpu { ...@@ -22,7 +22,7 @@ namespace gpu {
struct miopen_softmax struct miopen_softmax
{ {
softmax op; op::softmax op;
std::string name() const { return "gpu::softmax"; } std::string name() const { return "gpu::softmax"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
......
...@@ -12,9 +12,7 @@ struct target ...@@ -12,9 +12,7 @@ struct target
std::vector<pass> get_passes(migraph::context& gctx) const; std::vector<pass> get_passes(migraph::context& gctx) const;
migraph::context get_context() const; migraph::context get_context() const;
}; };
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -94,7 +94,7 @@ struct miopen_apply ...@@ -94,7 +94,7 @@ struct miopen_apply
instruction_ref apply_convolution(instruction_ref ins) instruction_ref apply_convolution(instruction_ref ins)
{ {
auto&& op = any_cast<convolution>(ins->get_operator()); auto&& op = any_cast<op::convolution>(ins->get_operator());
auto conv = miopen_convolution{op, make_conv(op)}; auto conv = miopen_convolution{op, make_conv(op)};
auto ws = conv.compile(ctx, ins->get_shape(), ins->inputs()); auto ws = conv.compile(ctx, ins->get_shape(), ins->inputs());
...@@ -108,7 +108,7 @@ struct miopen_apply ...@@ -108,7 +108,7 @@ struct miopen_apply
instruction_ref apply_pooling(instruction_ref ins) instruction_ref apply_pooling(instruction_ref ins)
{ {
auto&& op = any_cast<pooling>(ins->get_operator()); auto&& op = any_cast<op::pooling>(ins->get_operator());
auto pd = make_pooling(op); auto pd = make_pooling(op);
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
...@@ -118,7 +118,7 @@ struct miopen_apply ...@@ -118,7 +118,7 @@ struct miopen_apply
instruction_ref apply_activation(instruction_ref ins) instruction_ref apply_activation(instruction_ref ins)
{ {
auto&& op = any_cast<activation>(ins->get_operator()); auto&& op = any_cast<op::activation>(ins->get_operator());
auto ad = make_relu(); auto ad = make_relu();
if(op.mode == "relu") if(op.mode == "relu")
{ {
...@@ -131,7 +131,7 @@ struct miopen_apply ...@@ -131,7 +131,7 @@ struct miopen_apply
instruction_ref apply_softmax(instruction_ref ins) instruction_ref apply_softmax(instruction_ref ins)
{ {
auto&& op = any_cast<softmax>(ins->get_operator()); auto&& op = any_cast<op::softmax>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(ins, miopen_softmax{op}, ins->inputs().at(0), output); return prog->replace_instruction(ins, miopen_softmax{op}, ins->inputs().at(0), output);
} }
...@@ -145,7 +145,7 @@ struct miopen_apply ...@@ -145,7 +145,7 @@ struct miopen_apply
instruction_ref apply_gemm(instruction_ref ins) instruction_ref apply_gemm(instruction_ref ins)
{ {
auto&& op = any_cast<gemm>(ins->get_operator()); auto&& op = any_cast<op::gemm>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction( return prog->replace_instruction(
ins, miopen_gemm{op}, ins->inputs().at(0), ins->inputs().at(1), output); ins, miopen_gemm{op}, ins->inputs().at(0), ins->inputs().at(1), output);
...@@ -153,18 +153,18 @@ struct miopen_apply ...@@ -153,18 +153,18 @@ struct miopen_apply
instruction_ref apply_contiguous(instruction_ref ins) instruction_ref apply_contiguous(instruction_ref ins)
{ {
auto&& op = any_cast<contiguous>(ins->get_operator()); auto&& op = any_cast<op::contiguous>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(ins, miopen_contiguous{op}, ins->inputs().at(0), output); return prog->replace_instruction(ins, miopen_contiguous{op}, ins->inputs().at(0), output);
} }
instruction_ref apply_batch_norm_inference(instruction_ref ins) instruction_ref apply_batch_norm_inference(instruction_ref ins)
{ {
auto&& op = any_cast<batch_norm_inference>(ins->get_operator()); auto&& op = any_cast<op::batch_norm_inference>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
shape old_shape = ins->inputs().at(1)->get_shape(); shape old_shape = ins->inputs().at(1)->get_shape();
std::vector<int64_t> new_shape{1, static_cast<int64_t>(old_shape.elements()), 1, 1}; std::vector<int64_t> new_shape{1, static_cast<int64_t>(old_shape.elements()), 1, 1};
auto reshape_op = reshape{new_shape}; auto reshape_op = op::reshape{new_shape};
std::vector<instruction_ref> reshapes; std::vector<instruction_ref> reshapes;
std::transform(ins->inputs().begin() + 1, std::transform(ins->inputs().begin() + 1,
ins->inputs().end(), ins->inputs().end(),
...@@ -182,7 +182,5 @@ struct miopen_apply ...@@ -182,7 +182,5 @@ struct miopen_apply
}; };
void lowering::apply(program& p) const { miopen_apply{&p, ctx}.apply(); } void lowering::apply(program& p) const { miopen_apply{&p, ctx}.apply(); }
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
#include <migraph/gpu/target.hpp> #include <migraph/gpu/target.hpp>
#include <migraph/gpu/lowering.hpp> #include <migraph/gpu/lowering.hpp>
#include <migraph/memory_coloring.hpp>
#include <migraph/gpu/write_literals.hpp> #include <migraph/gpu/write_literals.hpp>
#include <migraph/gpu/context.hpp> #include <migraph/gpu/context.hpp>
#include <migraph/gpu/eliminate_workspace.hpp> #include <migraph/gpu/eliminate_workspace.hpp>
...@@ -28,6 +29,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const ...@@ -28,6 +29,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
simplify_reshapes{}, simplify_reshapes{},
dead_code_elimination{}, dead_code_elimination{},
lowering{ctx}, lowering{ctx},
memory_coloring{"hip::allocate"},
fuse_ops{}, fuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
eliminate_contiguous{}, eliminate_contiguous{},
...@@ -45,10 +47,8 @@ std::string target::name() const { return "miopen"; } ...@@ -45,10 +47,8 @@ std::string target::name() const { return "miopen"; }
migraph::context target::get_context() const migraph::context target::get_context() const
{ {
return context{share(make_obj<miopen_handle>(&miopenCreate)), return context{
share(create_rocblas_handle_ptr())}; share(make_obj<miopen_handle>(&miopenCreate)), share(create_rocblas_handle_ptr()), {}};
} }
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
...@@ -37,7 +37,5 @@ void write_literals::apply(program& p) const ...@@ -37,7 +37,5 @@ void write_literals::apply(program& p) const
} }
} }
} }
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
...@@ -41,7 +41,7 @@ void after_literal_transpose() ...@@ -41,7 +41,7 @@ void after_literal_transpose()
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
auto t = p.add_instruction(migraph::transpose{{1, 0}}, l); auto t = p.add_instruction(migraph::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
...@@ -57,7 +57,7 @@ void after_literal_broadcast() ...@@ -57,7 +57,7 @@ void after_literal_broadcast()
auto l2 = p.add_literal(get_2()); auto l2 = p.add_literal(get_2());
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraph::broadcast{}, l1, l2); auto b = p.add_instruction(migraph::op::broadcast{}, l1, l2);
p.add_instruction(pass_op{}, b); p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
...@@ -72,7 +72,7 @@ void after_param_transpose() ...@@ -72,7 +72,7 @@ void after_param_transpose()
auto l = p.add_parameter("2x2", {migraph::shape::float_type, {2, 2}}); auto l = p.add_parameter("2x2", {migraph::shape::float_type, {2, 2}});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
auto t = p.add_instruction(migraph::transpose{{1, 0}}, l); auto t = p.add_instruction(migraph::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
...@@ -88,7 +88,7 @@ void after_param_broadcast() ...@@ -88,7 +88,7 @@ void after_param_broadcast()
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {2}}); auto l2 = p.add_parameter("2", {migraph::shape::float_type, {2}});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraph::broadcast{}, l1, l2); auto b = p.add_instruction(migraph::op::broadcast{}, l1, l2);
p.add_instruction(pass_op{}, b); p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
......
...@@ -6,6 +6,109 @@ ...@@ -6,6 +6,109 @@
#include <migraph/verify.hpp> #include <migraph/verify.hpp>
#include "test.hpp" #include "test.hpp"
void slice_test()
{
{
migraph::program p;
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraph::shape s{migraph::shape::int32_type, {2, 2, 3}};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(migraph::op::slice{{2}, {1}, {3}}, l0);
migraph::shape s2{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
EXPECT(p.get_shape() == s2);
p.compile(migraph::cpu::cpu_target{});
migraph::shape sresult{migraph::shape::int32_type, {2, 2, 2}, {4, 2, 1}};
auto result = p.eval({});
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
{
migraph::program p;
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraph::shape s{migraph::shape::int32_type, {2, 2, 3}};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(migraph::op::slice{{0, 1, 2}, {0, 0, 0}, {2, 2, 2}}, l0);
migraph::shape s2{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
EXPECT(p.get_shape() == s2);
p.compile(migraph::cpu::cpu_target{});
migraph::shape sresult{migraph::shape::int32_type, {2, 2, 2}, {4, 2, 1}};
auto result = p.eval({});
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
}
void squeeze_test()
{
{
migraph::program p;
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 3, 1, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::op::squeeze{{1}}, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
}
{
migraph::program p;
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 1, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::op::squeeze{{3}}, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
}
{
migraph::program p;
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::op::squeeze{}, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
}
}
void unsqueeze_test()
{
{
migraph::program p;
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 3, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 1, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::op::unsqueeze{{1}}, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
}
{
migraph::program p;
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 3, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 3, 1, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::op::unsqueeze{{2}}, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
}
}
void im2col_3x3_no_pad_identity_test() void im2col_3x3_no_pad_identity_test()
{ {
std::size_t f[2] = {3, 3}; std::size_t f[2] = {3, 3};
...@@ -24,7 +127,7 @@ void im2col_3x3_no_pad_identity_test() ...@@ -24,7 +127,7 @@ void im2col_3x3_no_pad_identity_test()
migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}}; migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_image = p.add_literal(migraph::literal{s_image, input}); auto l_image = p.add_literal(migraph::literal{s_image, input});
auto l_weights = p.add_literal(migraph::literal{s_weights, weights}); auto l_weights = p.add_literal(migraph::literal{s_weights, weights});
p.add_instruction(migraph::im2col{padding, stride, dilation}, l_image, l_weights); p.add_instruction(migraph::op::im2col{padding, stride, dilation}, l_image, l_weights);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -53,7 +156,7 @@ void im2col_3x3_no_pad_test() ...@@ -53,7 +156,7 @@ void im2col_3x3_no_pad_test()
migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}}; migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_image = p.add_literal(migraph::literal{s_image, input}); auto l_image = p.add_literal(migraph::literal{s_image, input});
auto l_weights = p.add_literal(migraph::literal{s_weights, weights}); auto l_weights = p.add_literal(migraph::literal{s_weights, weights});
p.add_instruction(migraph::im2col{padding, stride, dilation}, l_image, l_weights); p.add_instruction(migraph::op::im2col{padding, stride, dilation}, l_image, l_weights);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -85,7 +188,7 @@ void im2col_3x3_stride_2_no_pad_test() ...@@ -85,7 +188,7 @@ void im2col_3x3_stride_2_no_pad_test()
migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}}; migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_image = p.add_literal(migraph::literal{s_image, input}); auto l_image = p.add_literal(migraph::literal{s_image, input});
auto l_weights = p.add_literal(migraph::literal{s_weights, weights}); auto l_weights = p.add_literal(migraph::literal{s_weights, weights});
p.add_instruction(migraph::im2col{padding, stride, dilation}, l_image, l_weights); p.add_instruction(migraph::op::im2col{padding, stride, dilation}, l_image, l_weights);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -118,7 +221,7 @@ void im2col_3x3_with_padding_test() ...@@ -118,7 +221,7 @@ void im2col_3x3_with_padding_test()
migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}}; migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_image = p.add_literal(migraph::literal{s_image, input}); auto l_image = p.add_literal(migraph::literal{s_image, input});
auto l_weights = p.add_literal(migraph::literal{s_weights, weights}); auto l_weights = p.add_literal(migraph::literal{s_weights, weights});
p.add_instruction(migraph::im2col{padding, stride, dilation}, l_image, l_weights); p.add_instruction(migraph::op::im2col{padding, stride, dilation}, l_image, l_weights);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -160,7 +263,7 @@ void batch_norm_inference_test() ...@@ -160,7 +263,7 @@ void batch_norm_inference_test()
auto mean = p.add_literal(migraph::literal{vars, mean_data}); auto mean = p.add_literal(migraph::literal{vars, mean_data});
auto variance = p.add_literal(migraph::literal{vars, variance_data}); auto variance = p.add_literal(migraph::literal{vars, variance_data});
p.add_instruction(migraph::batch_norm_inference{}, x, scale, bias, mean, variance); p.add_instruction(migraph::op::batch_norm_inference{}, x, scale, bias, mean, variance);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -190,7 +293,7 @@ void im2col_3x3_with_channels_identity_test() ...@@ -190,7 +293,7 @@ void im2col_3x3_with_channels_identity_test()
migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}}; migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_image = p.add_literal(migraph::literal{s_image, input}); auto l_image = p.add_literal(migraph::literal{s_image, input});
auto l_weights = p.add_literal(migraph::literal{s_weights, weights}); auto l_weights = p.add_literal(migraph::literal{s_weights, weights});
p.add_instruction(migraph::im2col{padding, stride, dilation}, l_image, l_weights); p.add_instruction(migraph::op::im2col{padding, stride, dilation}, l_image, l_weights);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -206,7 +309,7 @@ void exp_test() ...@@ -206,7 +309,7 @@ void exp_test()
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}});
p.add_instruction(migraph::exp{}, l); p.add_instruction(migraph::op::exp{}, l);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
...@@ -220,7 +323,7 @@ void sin_test() ...@@ -220,7 +323,7 @@ void sin_test()
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}});
p.add_instruction(migraph::sin{}, l); p.add_instruction(migraph::op::sin{}, l);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
...@@ -234,7 +337,7 @@ void cos_test() ...@@ -234,7 +337,7 @@ void cos_test()
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}});
p.add_instruction(migraph::cos{}, l); p.add_instruction(migraph::op::cos{}, l);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
...@@ -248,7 +351,7 @@ void tan_test() ...@@ -248,7 +351,7 @@ void tan_test()
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}});
p.add_instruction(migraph::tan{}, l); p.add_instruction(migraph::op::tan{}, l);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
...@@ -263,7 +366,7 @@ void add_test() ...@@ -263,7 +366,7 @@ void add_test()
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto l1 = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l1 = p.add_literal(migraph::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(migraph::literal{s, {1, 2, 3}}); auto l2 = p.add_literal(migraph::literal{s, {1, 2, 3}});
p.add_instruction(migraph::add{}, l1, l2); p.add_instruction(migraph::op::add{}, l1, l2);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
...@@ -282,7 +385,7 @@ void broadcast_test() ...@@ -282,7 +385,7 @@ void broadcast_test()
uint64_t axis = 0; uint64_t axis = 0;
auto l1 = p.add_literal(migraph::literal{a_shape, a_data}); auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data}); auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
p.add_instruction(migraph::broadcast{axis}, l1, l2); p.add_instruction(migraph::op::broadcast{axis}, l1, l2);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
auto output = result.get<int32_t>(); auto output = result.get<int32_t>();
...@@ -301,8 +404,8 @@ void add_broadcast_test() ...@@ -301,8 +404,8 @@ void add_broadcast_test()
uint64_t axis = 0; uint64_t axis = 0;
auto l1 = p.add_literal(migraph::literal{a_shape, a_data}); auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data}); auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraph::broadcast{axis}, l1, l2); auto l3 = p.add_instruction(migraph::op::broadcast{axis}, l1, l2);
p.add_instruction(migraph::add{}, l1, l3); p.add_instruction(migraph::op::add{}, l1, l3);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape().packed()); EXPECT(result.get_shape().packed());
...@@ -318,7 +421,7 @@ void sub_test() ...@@ -318,7 +421,7 @@ void sub_test()
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto l1 = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l1 = p.add_literal(migraph::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(migraph::literal{s, {1, 2, 3}}); auto l2 = p.add_literal(migraph::literal{s, {1, 2, 3}});
p.add_instruction(migraph::sub{}, l1, l2); p.add_instruction(migraph::op::sub{}, l1, l2);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
...@@ -333,7 +436,7 @@ void mul_test() ...@@ -333,7 +436,7 @@ void mul_test()
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto l1 = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l1 = p.add_literal(migraph::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(migraph::literal{s, {1, 2, 3}}); auto l2 = p.add_literal(migraph::literal{s, {1, 2, 3}});
p.add_instruction(migraph::mul{}, l1, l2); p.add_instruction(migraph::op::mul{}, l1, l2);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
...@@ -348,7 +451,7 @@ void div_test() ...@@ -348,7 +451,7 @@ void div_test()
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto l1 = p.add_literal(migraph::literal{s, {-1.0f, 0.5f, 1.0f}}); auto l1 = p.add_literal(migraph::literal{s, {-1.0f, 0.5f, 1.0f}});
auto l2 = p.add_literal(migraph::literal{s, {1.0f, 2.0f, 4.0f}}); auto l2 = p.add_literal(migraph::literal{s, {1.0f, 2.0f, 4.0f}});
p.add_instruction(migraph::div{}, l1, l2); p.add_instruction(migraph::op::div{}, l1, l2);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
...@@ -366,7 +469,7 @@ void reshape_test() ...@@ -366,7 +469,7 @@ void reshape_test()
migraph::program p; migraph::program p;
auto l = p.add_literal(migraph::literal{a_shape, data}); auto l = p.add_literal(migraph::literal{a_shape, data});
std::vector<int64_t> new_shape = {8, 3, 1, 1}; std::vector<int64_t> new_shape = {8, 3, 1, 1};
p.add_instruction(migraph::reshape{new_shape}, l); p.add_instruction(migraph::op::reshape{new_shape}, l);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
...@@ -377,7 +480,7 @@ void reshape_test() ...@@ -377,7 +480,7 @@ void reshape_test()
migraph::program p; migraph::program p;
auto l = p.add_literal(migraph::literal{a_shape, data}); auto l = p.add_literal(migraph::literal{a_shape, data});
std::vector<int64_t> new_shape = {1, 3, 4, 2}; std::vector<int64_t> new_shape = {1, 3, 4, 2};
p.add_instruction(migraph::reshape{new_shape}, l); p.add_instruction(migraph::op::reshape{new_shape}, l);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
...@@ -388,7 +491,7 @@ void reshape_test() ...@@ -388,7 +491,7 @@ void reshape_test()
migraph::program p; migraph::program p;
auto l = p.add_literal(migraph::literal{a_shape, data}); auto l = p.add_literal(migraph::literal{a_shape, data});
std::vector<int64_t> new_shape = {1, 3, 4, 2}; std::vector<int64_t> new_shape = {1, 3, 4, 2};
p.add_instruction(migraph::reshape{new_shape}, l); p.add_instruction(migraph::op::reshape{new_shape}, l);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
...@@ -436,7 +539,7 @@ void gemm_test() ...@@ -436,7 +539,7 @@ void gemm_test()
auto al = p.add_literal(migraph::literal{a_shape, a}); auto al = p.add_literal(migraph::literal{a_shape, a});
migraph::shape b_shape{migraph::shape::get_type<T>{}, {5, 3}}; migraph::shape b_shape{migraph::shape::get_type<T>{}, {5, 3}};
auto bl = p.add_literal(migraph::literal{b_shape, b}); auto bl = p.add_literal(migraph::literal{b_shape, b});
p.add_instruction(migraph::gemm{}, al, bl); p.add_instruction(migraph::op::gemm{}, al, bl);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<T> results_vector(12); std::vector<T> results_vector(12);
...@@ -491,7 +594,7 @@ void maxpool_test() ...@@ -491,7 +594,7 @@ void maxpool_test()
0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802}; 0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};
migraph::shape a_shape{migraph::shape::float_type, {2, 3, 6, 6}}; migraph::shape a_shape{migraph::shape::float_type, {2, 3, 6, 6}};
auto al = p.add_literal(migraph::literal{a_shape, a}); auto al = p.add_literal(migraph::literal{a_shape, a});
p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{3, 2}}}, al); p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{3, 2}}}, al);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::cout << result.get_shape() << std::endl; std::cout << result.get_shape() << std::endl;
...@@ -556,7 +659,7 @@ void softmax_test() ...@@ -556,7 +659,7 @@ void softmax_test()
migraph::shape a_shape{migraph::shape::float_type, {5, 3, 4, 2}}; migraph::shape a_shape{migraph::shape::float_type, {5, 3, 4, 2}};
auto al = p.add_literal(migraph::literal{a_shape, a}); auto al = p.add_literal(migraph::literal{a_shape, a});
p.add_instruction(migraph::softmax{}, al); p.add_instruction(migraph::op::softmax{}, al);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(120); std::vector<float> results_vector(120);
...@@ -618,7 +721,7 @@ void conv2d_test() ...@@ -618,7 +721,7 @@ void conv2d_test()
migraph::shape c_shape{migraph::shape::float_type, {2, 3, 3, 3}}; migraph::shape c_shape{migraph::shape::float_type, {2, 3, 3, 3}};
auto cl = p.add_literal(migraph::literal{c_shape, c}); auto cl = p.add_literal(migraph::literal{c_shape, c});
p.add_instruction(migraph::convolution{}, al, cl); p.add_instruction(migraph::op::convolution{}, al, cl);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -674,7 +777,7 @@ void conv2d_padding_test() ...@@ -674,7 +777,7 @@ void conv2d_padding_test()
migraph::shape c_shape{migraph::shape::float_type, {2, 3, 3, 3}}; migraph::shape c_shape{migraph::shape::float_type, {2, 3, 3, 3}};
auto cl = p.add_literal(migraph::literal{c_shape, c}); auto cl = p.add_literal(migraph::literal{c_shape, c});
p.add_instruction(migraph::convolution{{{1, 1}}, {{1, 1}}}, al, cl); p.add_instruction(migraph::op::convolution{{{1, 1}}, {{1, 1}}}, al, cl);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -735,7 +838,7 @@ void conv2d_padding_stride_test() ...@@ -735,7 +838,7 @@ void conv2d_padding_stride_test()
migraph::shape c_shape{migraph::shape::float_type, {2, 3, 3, 3}}; migraph::shape c_shape{migraph::shape::float_type, {2, 3, 3, 3}};
auto cl = p.add_literal(migraph::literal{c_shape, c}); auto cl = p.add_literal(migraph::literal{c_shape, c});
p.add_instruction(migraph::convolution{{{1, 1}}, {{2, 2}}}, al, cl); p.add_instruction(migraph::op::convolution{{{1, 1}}, {{2, 2}}}, al, cl);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -754,7 +857,7 @@ void transpose_test() ...@@ -754,7 +857,7 @@ void transpose_test()
migraph::program p; migraph::program p;
auto l = p.add_literal(migraph::literal{a_shape, data}); auto l = p.add_literal(migraph::literal{a_shape, data});
std::vector<int64_t> perm = {0, 3, 1, 2}; std::vector<int64_t> perm = {0, 3, 1, 2};
p.add_instruction(migraph::transpose{perm}, l); p.add_instruction(migraph::op::transpose{perm}, l);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -767,8 +870,8 @@ void transpose_test() ...@@ -767,8 +870,8 @@ void transpose_test()
migraph::program p; migraph::program p;
auto l = p.add_literal(migraph::literal{a_shape, data}); auto l = p.add_literal(migraph::literal{a_shape, data});
std::vector<int64_t> perm = {0, 3, 1, 2}; std::vector<int64_t> perm = {0, 3, 1, 2};
auto result = p.add_instruction(migraph::transpose{perm}, l); auto result = p.add_instruction(migraph::op::transpose{perm}, l);
p.add_instruction(migraph::contiguous{}, result); p.add_instruction(migraph::op::contiguous{}, result);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result2 = p.eval({}); auto result2 = p.eval({});
...@@ -787,7 +890,7 @@ void contiguous_test() ...@@ -787,7 +890,7 @@ void contiguous_test()
migraph::program p; migraph::program p;
auto l = p.add_literal(migraph::literal{a_shape, data}); auto l = p.add_literal(migraph::literal{a_shape, data});
p.add_instruction(migraph::contiguous{}, l); p.add_instruction(migraph::op::contiguous{}, l);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -795,12 +898,15 @@ void contiguous_test() ...@@ -795,12 +898,15 @@ void contiguous_test()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<size_t> new_lens = {1, 3, 2, 2}; std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3}; std::vector<size_t> new_strides = {12, 1, 6, 3};
std::vector<float> gold = {1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 0}; std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
int main() int main()
{ {
slice_test();
squeeze_test();
unsqueeze_test();
exp_test(); exp_test();
sin_test(); sin_test();
cos_test(); cos_test();
...@@ -814,7 +920,7 @@ int main() ...@@ -814,7 +920,7 @@ int main()
gemm_test<double>(); gemm_test<double>();
reshape_test(); reshape_test();
transpose_test(); transpose_test();
contiguous_test(); // contiguous_test();
softmax_test(); softmax_test();
// maxpool_test(); // maxpool_test();
conv2d_test(); conv2d_test();
......
...@@ -102,6 +102,7 @@ void float_aligned() ...@@ -102,6 +102,7 @@ void float_aligned()
int main() int main()
{ {
setenv("MIGRAPH_DISABLE_MEMORY_COLORING", "1", 1);
basic(); basic();
aligned(); aligned();
unaligned(); unaligned();
......
...@@ -18,8 +18,8 @@ void standard_op() ...@@ -18,8 +18,8 @@ void standard_op()
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraph::transpose{{1, 0}}, l); auto t = p.add_instruction(migraph::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraph::contiguous{}, t); auto c = p.add_instruction(migraph::op::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c); p.add_instruction(pass_standard_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{}); p.compile(eliminate_contiguous_target{});
...@@ -30,8 +30,8 @@ void non_standard_op() ...@@ -30,8 +30,8 @@ void non_standard_op()
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraph::transpose{{1, 0}}, l); auto t = p.add_instruction(migraph::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraph::contiguous{}, t); auto c = p.add_instruction(migraph::op::contiguous{}, t);
p.add_instruction(pass_op{}, c); p.add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{}); p.compile(eliminate_contiguous_target{});
......
...@@ -35,10 +35,7 @@ std::future<typename std::result_of<Function()>::type> detach_async(Function&& f ...@@ -35,10 +35,7 @@ std::future<typename std::result_of<Function()>::type> detach_async(Function&& f
std::thread(std::move(task)).detach(); std::thread(std::move(task)).detach();
return std::move(fut); return std::move(fut);
} }
else return std::async(std::launch::deferred, std::forward<Function>(f));
{
return std::async(std::launch::deferred, std::forward<Function>(f));
}
} }
struct auto_print struct auto_print
...@@ -157,8 +154,8 @@ struct test_literals ...@@ -157,8 +154,8 @@ struct test_literals
generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}})); generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}));
auto weights = p.add_literal( auto weights = p.add_literal(
generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}})); generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}));
auto conv = p.add_instruction(migraph::convolution{}, input, weights); auto conv = p.add_instruction(migraph::op::convolution{}, input, weights);
p.add_instruction(migraph::activation{"relu"}, conv); p.add_instruction(migraph::op::activation{"relu"}, conv);
return p; return p;
} }
}; };
...@@ -171,7 +168,7 @@ struct test_add ...@@ -171,7 +168,7 @@ struct test_add
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
p.add_instruction(migraph::add{}, x, y); p.add_instruction(migraph::op::add{}, x, y);
return p; return p;
} }
}; };
...@@ -184,8 +181,8 @@ struct test_add_broadcast ...@@ -184,8 +181,8 @@ struct test_add_broadcast
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 2, 3}}); auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {2, 2}}); auto y = p.add_parameter("y", {migraph::shape::float_type, {2, 2}});
auto by = p.add_instruction(migraph::broadcast{0}, x, y); auto by = p.add_instruction(migraph::op::broadcast{0}, x, y);
p.add_instruction(migraph::add{}, x, by); p.add_instruction(migraph::op::add{}, x, by);
return p; return p;
} }
}; };
...@@ -198,8 +195,8 @@ struct test_add_broadcast2 ...@@ -198,8 +195,8 @@ struct test_add_broadcast2
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 3, 4}}); auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 3, 4}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {3}}); auto y = p.add_parameter("y", {migraph::shape::float_type, {3}});
auto by = p.add_instruction(migraph::broadcast{1}, x, y); auto by = p.add_instruction(migraph::op::broadcast{1}, x, y);
p.add_instruction(migraph::add{}, x, by); p.add_instruction(migraph::op::add{}, x, by);
return p; return p;
} }
}; };
...@@ -212,8 +209,8 @@ struct test_add_broadcast3 ...@@ -212,8 +209,8 @@ struct test_add_broadcast3
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 4, 5}}); auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 4, 5}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {4}}); auto y = p.add_parameter("y", {migraph::shape::float_type, {4}});
auto by = p.add_instruction(migraph::broadcast{1}, x, y); auto by = p.add_instruction(migraph::op::broadcast{1}, x, y);
p.add_instruction(migraph::add{}, x, by); p.add_instruction(migraph::op::add{}, x, by);
return p; return p;
} }
}; };
...@@ -226,8 +223,8 @@ struct test_add_broadcast4 ...@@ -226,8 +223,8 @@ struct test_add_broadcast4
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 3, 5}}); auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 3, 5}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {3}}); auto y = p.add_parameter("y", {migraph::shape::float_type, {3}});
auto by = p.add_instruction(migraph::broadcast{1}, x, y); auto by = p.add_instruction(migraph::op::broadcast{1}, x, y);
p.add_instruction(migraph::add{}, x, by); p.add_instruction(migraph::op::add{}, x, by);
return p; return p;
} }
}; };
...@@ -240,8 +237,8 @@ struct test_add_broadcast5 ...@@ -240,8 +237,8 @@ struct test_add_broadcast5
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 4, 8}}); auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 4, 8}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {4}}); auto y = p.add_parameter("y", {migraph::shape::float_type, {4}});
auto by = p.add_instruction(migraph::broadcast{1}, x, y); auto by = p.add_instruction(migraph::op::broadcast{1}, x, y);
p.add_instruction(migraph::add{}, x, by); p.add_instruction(migraph::op::add{}, x, by);
return p; return p;
} }
}; };
...@@ -252,7 +249,7 @@ struct test_softmax ...@@ -252,7 +249,7 @@ struct test_softmax
{ {
migraph::program p; migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {5, 3, 4, 2}}); auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {5, 3, 4, 2}});
p.add_instruction(migraph::softmax{}, x); p.add_instruction(migraph::op::softmax{}, x);
return p; return p;
} }
}; };
...@@ -263,7 +260,7 @@ struct test_softmax2 ...@@ -263,7 +260,7 @@ struct test_softmax2
{ {
migraph::program p; migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 1000, 1, 1}}); auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 1000, 1, 1}});
p.add_instruction(migraph::softmax{}, x); p.add_instruction(migraph::op::softmax{}, x);
return p; return p;
} }
}; };
...@@ -276,7 +273,7 @@ struct test_conv ...@@ -276,7 +273,7 @@ struct test_conv
auto input = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); auto input = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
auto weights = auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); p.add_parameter("w", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
p.add_instruction(migraph::convolution{}, input, weights); p.add_instruction(migraph::op::convolution{}, input, weights);
return p; return p;
} }
}; };
...@@ -290,7 +287,7 @@ struct test_conv2 ...@@ -290,7 +287,7 @@ struct test_conv2
p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 512, 28, 28}}); p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 512, 28, 28}});
auto weights = auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::float_type, {256, 512, 1, 1}}); p.add_parameter("w", migraph::shape{migraph::shape::float_type, {256, 512, 1, 1}});
p.add_instruction(migraph::convolution{{0, 0}, {1, 1}, {1, 1}}, input, weights); p.add_instruction(migraph::op::convolution{{0, 0}, {1, 1}, {1, 1}}, input, weights);
return p; return p;
} }
}; };
...@@ -303,8 +300,8 @@ struct test_conv_relu ...@@ -303,8 +300,8 @@ struct test_conv_relu
auto input = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); auto input = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
auto weights = auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); p.add_parameter("w", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
auto conv = p.add_instruction(migraph::convolution{}, input, weights); auto conv = p.add_instruction(migraph::op::convolution{}, input, weights);
p.add_instruction(migraph::activation{"relu"}, conv); p.add_instruction(migraph::op::activation{"relu"}, conv);
return p; return p;
} }
}; };
...@@ -316,8 +313,8 @@ struct test_add_relu ...@@ -316,8 +313,8 @@ struct test_add_relu
migraph::program p; migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); auto y = p.add_parameter("y", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
auto add = p.add_instruction(migraph::add{}, x, y); auto add = p.add_instruction(migraph::op::add{}, x, y);
p.add_instruction(migraph::activation{"relu"}, add); p.add_instruction(migraph::op::activation{"relu"}, add);
return p; return p;
} }
}; };
...@@ -331,9 +328,9 @@ struct test_conv_pooling ...@@ -331,9 +328,9 @@ struct test_conv_pooling
p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 32, 32}}); p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 32, 32}});
auto weights = auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); p.add_parameter("w", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
auto conv = p.add_instruction(migraph::convolution{}, input, weights); auto conv = p.add_instruction(migraph::op::convolution{}, input, weights);
auto pooling = p.add_instruction(migraph::pooling{"max"}, conv); auto pooling = p.add_instruction(migraph::op::pooling{"max"}, conv);
p.add_instruction(migraph::activation{"relu"}, pooling); p.add_instruction(migraph::op::activation{"relu"}, pooling);
return p; return p;
} }
}; };
...@@ -345,7 +342,7 @@ struct test_gemm ...@@ -345,7 +342,7 @@ struct test_gemm
migraph::program p; migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}}); auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}}); auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}});
p.add_instruction(migraph::gemm{}, a, b); p.add_instruction(migraph::op::gemm{}, a, b);
return p; return p;
} }
}; };
...@@ -357,7 +354,7 @@ struct test_gemm_ld ...@@ -357,7 +354,7 @@ struct test_gemm_ld
migraph::program p; migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}, {10, 1}}); auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}, {10, 1}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}, {20, 1}}); auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}, {20, 1}});
p.add_instruction(migraph::gemm{}, a, b); p.add_instruction(migraph::op::gemm{}, a, b);
return p; return p;
} }
}; };
...@@ -369,8 +366,8 @@ struct test_gemm_transposeb ...@@ -369,8 +366,8 @@ struct test_gemm_transposeb
migraph::program p; migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}}); auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}}); auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}});
auto bt = p.add_instruction(migraph::transpose{{1, 0}}, b); auto bt = p.add_instruction(migraph::op::transpose{{1, 0}}, b);
p.add_instruction(migraph::gemm{}, a, bt); p.add_instruction(migraph::op::gemm{}, a, bt);
return p; return p;
} }
}; };
...@@ -382,8 +379,8 @@ struct test_gemm_transposea ...@@ -382,8 +379,8 @@ struct test_gemm_transposea
migraph::program p; migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}}); auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}}); auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}});
auto at = p.add_instruction(migraph::transpose{{1, 0}}, a); auto at = p.add_instruction(migraph::op::transpose{{1, 0}}, a);
p.add_instruction(migraph::gemm{}, at, b); p.add_instruction(migraph::op::gemm{}, at, b);
return p; return p;
} }
}; };
...@@ -395,9 +392,9 @@ struct test_gemm_transposeab ...@@ -395,9 +392,9 @@ struct test_gemm_transposeab
migraph::program p; migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}}); auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}}); auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}});
auto at = p.add_instruction(migraph::transpose{{1, 0}}, a); auto at = p.add_instruction(migraph::op::transpose{{1, 0}}, a);
auto bt = p.add_instruction(migraph::transpose{{1, 0}}, b); auto bt = p.add_instruction(migraph::op::transpose{{1, 0}}, b);
p.add_instruction(migraph::gemm{}, at, bt); p.add_instruction(migraph::op::gemm{}, at, bt);
return p; return p;
} }
}; };
...@@ -409,7 +406,7 @@ struct test_contiguous ...@@ -409,7 +406,7 @@ struct test_contiguous
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}}; migraph::shape s{migraph::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
p.add_instruction(migraph::contiguous{}, x); p.add_instruction(migraph::op::contiguous{}, x);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
return p; return p;
} }
...@@ -423,8 +420,8 @@ struct test_transpose ...@@ -423,8 +420,8 @@ struct test_transpose
migraph::shape s{migraph::shape::float_type, {4, 3, 4, 4}}; migraph::shape s{migraph::shape::float_type, {4, 3, 4, 4}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
std::vector<int64_t> perm = {0, 2, 3, 1}; std::vector<int64_t> perm = {0, 2, 3, 1};
auto l = p.add_instruction(migraph::transpose{perm}, x); auto l = p.add_instruction(migraph::op::transpose{perm}, x);
p.add_instruction(migraph::contiguous{}, l); p.add_instruction(migraph::op::contiguous{}, l);
return p; return p;
} }
}; };
...@@ -447,7 +444,7 @@ struct test_batchnorm_inference_2 ...@@ -447,7 +444,7 @@ struct test_batchnorm_inference_2
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2))); auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3))); auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4))); auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4)));
p.add_instruction(migraph::batch_norm_inference{}, x, scale, bias, mean, variance); p.add_instruction(migraph::op::batch_norm_inference{}, x, scale, bias, mean, variance);
return p; return p;
} }
}; };
...@@ -470,7 +467,7 @@ struct test_batchnorm_inference ...@@ -470,7 +467,7 @@ struct test_batchnorm_inference
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2))); auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3))); auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4))); auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4)));
p.add_instruction(migraph::batch_norm_inference{}, x, scale, bias, mean, variance); p.add_instruction(migraph::op::batch_norm_inference{}, x, scale, bias, mean, variance);
return p; return p;
} }
}; };
...@@ -486,12 +483,12 @@ struct test_conv_bn ...@@ -486,12 +483,12 @@ struct test_conv_bn
migraph::shape vars{migraph::shape::float_type, {64}}; migraph::shape vars{migraph::shape::float_type, {64}};
auto x = p.add_parameter("x", xs); auto x = p.add_parameter("x", xs);
auto w = p.add_parameter("w", ws); auto w = p.add_parameter("w", ws);
auto conv = p.add_instruction(migraph::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w); auto conv = p.add_instruction(migraph::op::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w);
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1))); auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2))); auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3))); auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4))); auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4)));
p.add_instruction(migraph::batch_norm_inference{}, conv, scale, bias, mean, variance); p.add_instruction(migraph::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
return p; return p;
} }
}; };
...@@ -507,15 +504,15 @@ struct test_conv_bn_relu_pooling ...@@ -507,15 +504,15 @@ struct test_conv_bn_relu_pooling
migraph::shape vars{migraph::shape::float_type, {64}}; migraph::shape vars{migraph::shape::float_type, {64}};
auto x = p.add_parameter("x", xs); auto x = p.add_parameter("x", xs);
auto w = p.add_parameter("w", ws); auto w = p.add_parameter("w", ws);
auto conv = p.add_instruction(migraph::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w); auto conv = p.add_instruction(migraph::op::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w);
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1))); auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2))); auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3))); auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4))); auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4)));
auto bn = auto bn = p.add_instruction(
p.add_instruction(migraph::batch_norm_inference{}, conv, scale, bias, mean, variance); migraph::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
auto relu = p.add_instruction(migraph::activation{"relu"}, bn); auto relu = p.add_instruction(migraph::op::activation{"relu"}, bn);
p.add_instruction(migraph::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu); p.add_instruction(migraph::op::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu);
return p; return p;
} }
}; };
...@@ -530,7 +527,8 @@ struct test_conv_bn_relu_pooling2 ...@@ -530,7 +527,8 @@ struct test_conv_bn_relu_pooling2
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2 + channels))); auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2 + channels)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3 + channels))); auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3 + channels)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4 + channels))); auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4 + channels)));
return p.add_instruction(migraph::batch_norm_inference{}, x, scale, bias, mean, variance); return p.add_instruction(
migraph::op::batch_norm_inference{}, x, scale, bias, mean, variance);
} }
migraph::program create_program() const migraph::program create_program() const
{ {
...@@ -542,15 +540,15 @@ struct test_conv_bn_relu_pooling2 ...@@ -542,15 +540,15 @@ struct test_conv_bn_relu_pooling2
migraph::shape ws2{migraph::shape::float_type, {2048, 1024, 1, 1}}; migraph::shape ws2{migraph::shape::float_type, {2048, 1024, 1, 1}};
auto x1 = p.add_parameter("x1", xs1); auto x1 = p.add_parameter("x1", xs1);
auto w1 = p.add_parameter("w1", ws1); auto w1 = p.add_parameter("w1", ws1);
auto conv1 = p.add_instruction(migraph::convolution{{0, 0}, {1, 1}, {1, 1}}, x1, w1); auto conv1 = p.add_instruction(migraph::op::convolution{{0, 0}, {1, 1}, {1, 1}}, x1, w1);
auto bn1 = add_bn(p, conv1, 2048); auto bn1 = add_bn(p, conv1, 2048);
auto x2 = p.add_parameter("x2", xs2); auto x2 = p.add_parameter("x2", xs2);
auto w2 = p.add_parameter("w2", ws2); auto w2 = p.add_parameter("w2", ws2);
auto conv2 = p.add_instruction(migraph::convolution{{0, 0}, {2, 2}, {1, 1}}, x2, w2); auto conv2 = p.add_instruction(migraph::op::convolution{{0, 0}, {2, 2}, {1, 1}}, x2, w2);
auto bn2 = add_bn(p, conv2, 2048); auto bn2 = add_bn(p, conv2, 2048);
auto add = p.add_instruction(migraph::add{}, bn1, bn2); auto add = p.add_instruction(migraph::op::add{}, bn1, bn2);
auto relu = p.add_instruction(migraph::activation{"relu"}, add); auto relu = p.add_instruction(migraph::op::activation{"relu"}, add);
p.add_instruction(migraph::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu); p.add_instruction(migraph::op::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu);
return p; return p;
} }
}; };
......
#include <migraph/matcher.hpp>
#include <migraph/iterator_for.hpp>
#include <test.hpp>
#include <basic_ops.hpp>
namespace matchers = migraph::matchers;
template <class M>
migraph::matchers::matcher_result find_match(migraph::program& p, M&& m)
{
migraph::matchers::matcher_result result;
for(auto ins : migraph::iterator_for(p))
{
result = migraph::matchers::match_instruction(p, ins, m);
if(result.result != p.end())
return result;
}
return result;
}
void match1()
{
migraph::program p;
auto l = p.add_literal(1);
auto m = matchers::standard_shape();
auto r = find_match(p, m);
EXPECT(bool{r.result == l});
}
void match_name1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum");
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_name2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("min");
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_name3()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_arg1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_arg2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("sum")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_arg3()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(1)(matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_arg4()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("pass")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == pass});
}
void match_arg5()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("pass")(matchers::arg(1)(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_arg6()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_arg7()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")),
matchers::arg(1)(matchers::name("@literal")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_args1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(
matchers::args(matchers::name("@literal"), matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_args2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("sum")(matchers::args(matchers::name("@literal"), matchers::name("sum")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_args3()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::args(matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_args4()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
matchers::name("sum")(matchers::args(matchers::name("sum"), matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
}
void match_args5()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("sum")(matchers::args(matchers::name("sum"), matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_args6()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("pass")(matchers::args(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == pass});
}
void match_args7()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum);
auto m = matchers::name("pass")(matchers::args(matchers::name("sum")(matchers::args(
matchers::name("@literal"), matchers::name("@literal")))),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == pass});
}
void match_all_of1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::all_of(matchers::arg(0)(matchers::name("@literal")),
matchers::arg(1)(matchers::name("@literal"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_all_of2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::all_of(matchers::arg(0)(matchers::name("sum")),
matchers::arg(1)(matchers::name("@literal"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_any_of1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::any_of(matchers::arg(0)(matchers::name("sum")),
matchers::arg(1)(matchers::name("@literal"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_any_of2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::any_of(matchers::arg(0)(matchers::name("sum")),
matchers::arg(1)(matchers::name("sum"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_none_of1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::none_of(matchers::arg(0)(matchers::name("sum")),
matchers::arg(1)(matchers::name("sum"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_none_of2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::none_of(matchers::arg(0)(matchers::name("@literal")),
matchers::arg(1)(matchers::name("@literal"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_bind1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum);
auto m = matchers::name("pass")(
matchers::args(
matchers::name("sum")(matchers::args(matchers::name("@literal").bind("one"),
matchers::name("@literal").bind("two")))
.bind("sum")),
matchers::standard_shape())
.bind("pass");
auto r = find_match(p, m);
EXPECT(bool{r.instructions.at("one") == one});
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.result == pass});
}
struct match_find_sum
{
migraph::instruction_ref ins;
auto matcher() const { return matchers::name("sum"); }
void apply(migraph::program&, matchers::matcher_result r) const
{
EXPECT(bool{r.result == ins});
}
};
struct match_find_literal
{
migraph::instruction_ref ins;
auto matcher() const { return matchers::name("@literal"); }
void apply(migraph::program&, matchers::matcher_result r) const
{
EXPECT(bool{r.result != ins});
EXPECT(r.result->name() == "@literal");
}
};
void match_finder()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
matchers::find_matches(p, match_find_sum{sum}, match_find_literal{sum});
}
int main()
{
match1();
match_name1();
match_name2();
match_name3();
match_arg1();
match_arg2();
match_arg3();
match_arg4();
match_arg5();
match_arg6();
match_arg7();
match_args1();
match_args2();
match_args3();
match_args4();
match_args5();
match_args6();
match_args7();
match_all_of1();
match_all_of2();
match_any_of1();
match_any_of2();
match_none_of1();
match_none_of2();
match_bind1();
match_finder();
}
#include <migraph/memory_coloring.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct memory_coloring_target
{
std::string name() const { return "memory_coloring"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::memory_coloring{"allocate"}};
}
migraph::context get_context() const { return {}; }
};
struct allocate
{
migraph::shape s{};
std::string name() const { return "allocate"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const
{
migraph::check_shapes{inputs, *this}.has(1);
return inputs.front();
}
migraph::argument compute(migraph::context&,
const migraph::shape& output_shape,
const std::vector<migraph::argument>&) const
{
return {output_shape};
}
};
// A custom test operator that takes a single argument and an allocation
// This operator's output is an operand alias of argument 1
struct pass_memory
{
std::string name() const { return "memory_coloring::pass_memory"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const
{
migraph::check_shapes{inputs, *this}.has(2);
return inputs.at(1);
}
migraph::argument compute(migraph::context&,
const migraph::shape&,
const std::vector<migraph::argument>& args) const
{
return args[1];
}
};
// The previous existing test
void test1()
{
migraph::program p;
auto a0 = p.add_outline(migraph::shape{migraph::shape::float_type, {8}});
auto a1 = p.add_instruction(allocate{}, a0);
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_outline(migraph::shape{migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(allocate{}, a2);
p.add_instruction(pass_op{}, p2, p1);
p.compile(memory_coloring_target{});
EXPECT(p.get_parameter_shape("scratch").bytes() == 192);
}
// This test uses the pass_memory operator
void test2()
{
migraph::program p;
auto input = p.add_parameter("input", migraph::shape{migraph::shape::float_type, {16}});
auto a0 = p.add_outline(migraph::shape{migraph::shape::float_type, {128}});
auto a1 = p.add_instruction(allocate{}, a0);
auto p1 = p.add_instruction(pass_memory{}, input, a1);
auto a2 = p.add_outline(migraph::shape{migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(allocate{}, a2);
p.add_instruction(pass_memory{}, p1, p2);
p.compile(memory_coloring_target{});
EXPECT(p.get_parameter_shape("scratch").bytes() == 672);
}
// This test uses the pass_memory operator with two memory allocation passed together.
// This is similar to allocations done for workspaces, that is one allocation is aliased and the
// other is just used
void test3()
{
migraph::program p;
auto a0 = p.add_outline(migraph::shape{migraph::shape::float_type, {8}});
auto a1 = p.add_instruction(allocate{}, a0);
auto a2 = p.add_outline(migraph::shape{migraph::shape::float_type, {128}});
auto p2 = p.add_instruction(allocate{}, a2);
auto p1 = p.add_instruction(pass_memory{}, a1, p2);
auto a3 = p.add_outline(migraph::shape{migraph::shape::float_type, {40}});
auto p3 = p.add_instruction(allocate{}, a3);
p.add_instruction(pass_memory{}, p1, p3);
p.compile(memory_coloring_target{});
EXPECT(p.get_parameter_shape("scratch").bytes() == 704);
}
// Like the previous test, but this tests a zero workspace memory allocation
void test4()
{
migraph::program p;
auto a0 = p.add_outline(migraph::shape{migraph::shape::float_type, {0}});
auto a1 = p.add_instruction(allocate{}, a0);
auto a2 = p.add_outline(migraph::shape{migraph::shape::float_type, {128}});
auto p2 = p.add_instruction(allocate{}, a2);
auto p1 = p.add_instruction(pass_memory{}, a1, p2);
auto a3 = p.add_outline(migraph::shape{migraph::shape::float_type, {40}});
auto p3 = p.add_instruction(allocate{}, a3);
p.add_instruction(pass_memory{}, p1, p3);
p.compile(memory_coloring_target{});
EXPECT(p.get_parameter_shape("scratch").bytes() == 672);
}
int main()
{
test1();
test2();
test3();
test4();
}
...@@ -13,9 +13,9 @@ void pytorch_conv_bias_test() ...@@ -13,9 +13,9 @@ void pytorch_conv_bias_test()
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}}); auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::convolution{}, l0, l1); auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::broadcast{axis}, l3, l2); auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2);
p.add_instruction(migraph::add{}, l3, l4); p.add_instruction(migraph::op::add{}, l3, l4);
auto prog = migraph::parse_onnx("conv.onnx"); auto prog = migraph::parse_onnx("conv.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -28,11 +28,11 @@ void pytorch_conv_relu_maxpool() ...@@ -28,11 +28,11 @@ void pytorch_conv_relu_maxpool()
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}}); auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::convolution{}, l0, l1); auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::broadcast{axis}, l3, l2); auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2);
auto l5 = p.add_instruction(migraph::add{}, l3, l4); auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::activation{"relu"}, l5); auto l6 = p.add_instruction(migraph::op::activation{"relu"}, l5);
p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto prog = migraph::parse_onnx("conv_relu_maxpool.onnx"); auto prog = migraph::parse_onnx("conv_relu_maxpool.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -50,12 +50,12 @@ void pytorch_conv_bn_relu_maxpool() ...@@ -50,12 +50,12 @@ void pytorch_conv_bn_relu_maxpool()
auto p5 = p.add_parameter("5", {migraph::shape::float_type, {1}}); auto p5 = p.add_parameter("5", {migraph::shape::float_type, {1}});
auto p6 = p.add_parameter("6", {migraph::shape::float_type, {1}}); auto p6 = p.add_parameter("6", {migraph::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::convolution{}, l0, l1); auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::broadcast{axis}, l3, l2); auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2);
auto l5 = p.add_instruction(migraph::add{}, l3, l4); auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::batch_norm_inference{}, l5, p3, p4, p5, p6); auto l6 = p.add_instruction(migraph::op::batch_norm_inference{}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraph::activation{"relu"}, l6); auto l7 = p.add_instruction(migraph::op::activation{"relu"}, l6);
p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7); p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
auto prog = migraph::parse_onnx("conv_bn_relu_maxpool.onnx"); auto prog = migraph::parse_onnx("conv_bn_relu_maxpool.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -68,19 +68,19 @@ void pytorch_conv_relu_maxpool_x2() ...@@ -68,19 +68,19 @@ void pytorch_conv_relu_maxpool_x2()
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {5, 3, 5, 5}}); auto l1 = p.add_parameter("1", {migraph::shape::float_type, {5, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {5}}); auto l2 = p.add_parameter("2", {migraph::shape::float_type, {5}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::convolution{}, l0, l1); auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::broadcast{axis}, l3, l2); auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2);
auto l5 = p.add_instruction(migraph::add{}, l3, l4); auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::activation{"relu"}, l5); auto l6 = p.add_instruction(migraph::op::activation{"relu"}, l5);
auto l7 = p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); auto l7 = p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto l8 = p.add_parameter("3", {migraph::shape::float_type, {1, 5, 5, 5}}); auto l8 = p.add_parameter("3", {migraph::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraph::shape::float_type, {1}}); auto l9 = p.add_parameter("4", {migraph::shape::float_type, {1}});
auto l10 = p.add_instruction(migraph::convolution{}, l7, l8); auto l10 = p.add_instruction(migraph::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraph::broadcast{axis}, l10, l9); auto l11 = p.add_instruction(migraph::op::broadcast{axis}, l10, l9);
auto l12 = p.add_instruction(migraph::add{}, l10, l11); auto l12 = p.add_instruction(migraph::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraph::activation{"relu"}, l12); auto l13 = p.add_instruction(migraph::op::activation{"relu"}, l12);
p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13); p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
auto prog = migraph::parse_onnx("conv_relu_maxpoolX2.onnx"); auto prog = migraph::parse_onnx("conv_relu_maxpoolX2.onnx");
......
...@@ -57,9 +57,9 @@ void batch_norm_inference_shape() ...@@ -57,9 +57,9 @@ void batch_norm_inference_shape()
const size_t channels = 3; const size_t channels = 3;
migraph::shape s{migraph::shape::float_type, {4, channels, 3, 3}}; migraph::shape s{migraph::shape::float_type, {4, channels, 3, 3}};
migraph::shape vars{migraph::shape::float_type, {channels}}; migraph::shape vars{migraph::shape::float_type, {channels}};
expect_shape(s, migraph::batch_norm_inference{}, s, vars, vars, vars, vars); expect_shape(s, migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars);
throws_shape(migraph::batch_norm_inference{}, s); throws_shape(migraph::op::batch_norm_inference{}, s);
throws_shape(migraph::batch_norm_inference{}, s, vars, vars, vars, vars, vars); throws_shape(migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars);
} }
void convolution_shape() void convolution_shape()
...@@ -67,33 +67,33 @@ void convolution_shape() ...@@ -67,33 +67,33 @@ void convolution_shape()
migraph::shape output{migraph::shape::float_type, {4, 4, 1, 1}}; migraph::shape output{migraph::shape::float_type, {4, 4, 1, 1}};
migraph::shape input{migraph::shape::float_type, {4, 3, 3, 3}}; migraph::shape input{migraph::shape::float_type, {4, 3, 3, 3}};
migraph::shape weights{migraph::shape::float_type, {4, 3, 3, 3}}; migraph::shape weights{migraph::shape::float_type, {4, 3, 3, 3}};
expect_shape(output, migraph::convolution{}, input, weights); expect_shape(output, migraph::op::convolution{}, input, weights);
throws_shape(migraph::convolution{}, input); throws_shape(migraph::op::convolution{}, input);
migraph::shape input2{migraph::shape::float_type, {3, 3}}; migraph::shape input2{migraph::shape::float_type, {3, 3}};
migraph::shape weights2{migraph::shape::float_type, {3, 3}}; migraph::shape weights2{migraph::shape::float_type, {3, 3}};
throws_shape(migraph::convolution{}, input2, weights2); throws_shape(migraph::op::convolution{}, input2, weights2);
throws_shape(migraph::convolution{}, input2, weights); throws_shape(migraph::op::convolution{}, input2, weights);
} }
void transpose_shape() void transpose_shape()
{ {
migraph::shape input{migraph::shape::float_type, {2, 2}}; migraph::shape input{migraph::shape::float_type, {2, 2}};
migraph::shape output{migraph::shape::float_type, {2, 2}, {1, 2}}; migraph::shape output{migraph::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraph::transpose{{0, 1}}, input); expect_shape(input, migraph::op::transpose{{0, 1}}, input);
expect_shape(output, migraph::transpose{{1, 0}}, input); expect_shape(output, migraph::op::transpose{{1, 0}}, input);
throws_shape(migraph::transpose{{1, 2}}, input); throws_shape(migraph::op::transpose{{1, 2}}, input);
} }
void contiguous_shape() void contiguous_shape()
{ {
migraph::shape output{migraph::shape::float_type, {2, 2}}; migraph::shape output{migraph::shape::float_type, {2, 2}};
migraph::shape input{migraph::shape::float_type, {2, 2}, {1, 2}}; migraph::shape input{migraph::shape::float_type, {2, 2}, {1, 2}};
expect_shape(output, migraph::contiguous{}, input); expect_shape(output, migraph::op::contiguous{}, input);
throws_shape(migraph::contiguous{}, input, input); throws_shape(migraph::op::contiguous{}, input, input);
migraph::shape single{migraph::shape::float_type, {2}}; migraph::shape single{migraph::shape::float_type, {2}};
throws_shape(migraph::contiguous{}, single); throws_shape(migraph::op::contiguous{}, single);
} }
void reshape_shape() void reshape_shape()
...@@ -105,31 +105,46 @@ void reshape_shape() ...@@ -105,31 +105,46 @@ void reshape_shape()
std::vector<std::size_t> lens(new_shape.size()); std::vector<std::size_t> lens(new_shape.size());
std::copy(new_shape.begin(), new_shape.end(), lens.begin()); std::copy(new_shape.begin(), new_shape.end(), lens.begin());
migraph::shape output{migraph::shape::float_type, lens}; migraph::shape output{migraph::shape::float_type, lens};
expect_shape(output, migraph::reshape{new_shape}, input); expect_shape(output, migraph::op::reshape{new_shape}, input);
} }
for(auto&& new_shape : std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}}) for(auto&& new_shape : std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}})
{ {
throws_shape(migraph::reshape{new_shape}, input); throws_shape(migraph::op::reshape{new_shape}, input);
} }
} }
void flatten_shape() void flatten_shape()
{ {
migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}}; migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}};
expect_shape(migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}},
migraph::op::flatten{0},
input);
expect_shape( expect_shape(
migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}}, migraph::flatten{0}, input); migraph::shape{migraph::shape::float_type, {2, 4 * 6 * 8}}, migraph::op::flatten{1}, input);
expect_shape( expect_shape(
migraph::shape{migraph::shape::float_type, {2, 4 * 6 * 8}}, migraph::flatten{1}, input); migraph::shape{migraph::shape::float_type, {2 * 4, 6 * 8}}, migraph::op::flatten{2}, input);
expect_shape( expect_shape(
migraph::shape{migraph::shape::float_type, {2 * 4, 6 * 8}}, migraph::flatten{2}, input); migraph::shape{migraph::shape::float_type, {2 * 4 * 6, 8}}, migraph::op::flatten{3}, input);
expect_shape( expect_shape(migraph::shape{migraph::shape::float_type, {2 * 4 * 6 * 8, 1}},
migraph::shape{migraph::shape::float_type, {2 * 4 * 6, 8}}, migraph::flatten{3}, input); migraph::op::flatten{4},
expect_shape( input);
migraph::shape{migraph::shape::float_type, {2 * 4 * 6 * 8, 1}}, migraph::flatten{4}, input); throws_shape(migraph::op::flatten{5}, input);
throws_shape(migraph::flatten{5}, input);
} }
void slice_shape()
{
migraph::shape input{migraph::shape::int32_type, {2, 2, 3}};
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::op::slice{{2}, {1}, {3}},
input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::op::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}},
input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraph::op::slice{{2}, {2}, {10}},
input);
}
int main() int main()
{ {
batch_norm_inference_shape(); batch_norm_inference_shape();
...@@ -138,4 +153,5 @@ int main() ...@@ -138,4 +153,5 @@ int main()
contiguous_shape(); contiguous_shape();
reshape_shape(); reshape_shape();
flatten_shape(); flatten_shape();
slice_shape();
} }
...@@ -18,9 +18,9 @@ void double_contig() ...@@ -18,9 +18,9 @@ void double_contig()
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraph::op::transpose{{1, 0}}, l);
auto c1 = p.add_instruction(migraph::contiguous{}, t1); auto c1 = p.add_instruction(migraph::op::contiguous{}, t1);
auto c2 = p.add_instruction(migraph::contiguous{}, c1); auto c2 = p.add_instruction(migraph::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2); p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
...@@ -36,8 +36,8 @@ void double_transpose() ...@@ -36,8 +36,8 @@ void double_transpose()
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraph::op::transpose{{1, 0}}, l);
auto t2 = p.add_instruction(migraph::transpose{{1, 0}}, t1); auto t2 = p.add_instruction(migraph::op::transpose{{1, 0}}, t1);
p.add_instruction(pass_op{}, t2); p.add_instruction(pass_op{}, t2);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
...@@ -53,10 +53,10 @@ void double_transpose_contig() ...@@ -53,10 +53,10 @@ void double_transpose_contig()
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraph::op::transpose{{1, 0}}, l);
auto c1 = p.add_instruction(migraph::contiguous{}, t1); auto c1 = p.add_instruction(migraph::op::contiguous{}, t1);
auto t2 = p.add_instruction(migraph::transpose{{1, 0}}, c1); auto t2 = p.add_instruction(migraph::op::transpose{{1, 0}}, c1);
auto c2 = p.add_instruction(migraph::contiguous{}, t2); auto c2 = p.add_instruction(migraph::op::contiguous{}, t2);
p.add_instruction(pass_op{}, c2); p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
...@@ -72,7 +72,7 @@ void single_transpose() ...@@ -72,7 +72,7 @@ void single_transpose()
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraph::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t1); p.add_instruction(pass_op{}, t1);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
...@@ -88,8 +88,8 @@ void double_transpose_sin_pass() ...@@ -88,8 +88,8 @@ void double_transpose_sin_pass()
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraph::op::transpose{{1, 0}}, l);
p.add_instruction(migraph::transpose{{1, 0}}, t1); p.add_instruction(migraph::op::transpose{{1, 0}}, t1);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{}); p.compile(simplify_reshapes_target{});
...@@ -106,7 +106,7 @@ void single_transpose_sin_pass() ...@@ -106,7 +106,7 @@ void single_transpose_sin_pass()
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
p.add_instruction(migraph::transpose{{1, 0}}, l); p.add_instruction(migraph::op::transpose{{1, 0}}, l);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
p.compile(simplify_reshapes_target{}); p.compile(simplify_reshapes_target{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment