"src/vscode:/vscode.git/clone" did not exist on "e209869c42223a709f192db3b1c7c38c35576f0b"
Commit 5dfeb457 authored by Scott Thornton

Merge branch 'master' into squeeze_unsqeeze

parents 523a78c7 f9f4f713
@@ -579,17 +579,17 @@ struct cpu_apply
         init();
         for(auto it : iterator_for(*prog))
         {
-            if(it->op.name() == "activation")
+            if(it->name() == "activation")
             {
                 apply_activation(it);
             }
-            else if(it->op.name() == "pooling")
+            else if(it->name() == "pooling")
             {
                 apply_pooling(it);
             }
-            else if(apply_map.count(it->op.name()) > 0)
+            else if(apply_map.count(it->name()) > 0)
             {
-                apply_map.at(it->op.name())(it);
+                apply_map.at(it->name())(it);
             }
         }
     }
@@ -597,30 +597,30 @@ struct cpu_apply
     template <class T>
     void apply_simple_op(instruction_ref ins)
     {
-        prog->replace_instruction(ins, T{}, ins->arguments);
+        prog->replace_instruction(ins, T{}, ins->inputs());
     }
     template <class T, class Op>
     void apply_extend_op(instruction_ref ins)
     {
-        auto&& op = any_cast<Op>(ins->op);
-        prog->replace_instruction(ins, T{op}, ins->arguments);
+        auto&& op = any_cast<Op>(ins->get_operator());
+        prog->replace_instruction(ins, T{op}, ins->inputs());
     }
     void apply_activation(instruction_ref ins)
     {
-        auto&& op = any_cast<activation>(ins->op);
+        auto&& op = any_cast<activation>(ins->get_operator());
         if(op.mode == "relu")
-            prog->replace_instruction(ins, cpu_unary<relu_op>{}, ins->arguments);
+            prog->replace_instruction(ins, cpu_unary<relu_op>{}, ins->inputs());
     }
     void apply_pooling(instruction_ref ins)
     {
-        auto&& op = any_cast<pooling>(ins->op);
+        auto&& op = any_cast<pooling>(ins->get_operator());
         if(op.mode == "max")
-            prog->replace_instruction(ins, cpu_pooling<max_pool>{op}, ins->arguments);
+            prog->replace_instruction(ins, cpu_pooling<max_pool>{op}, ins->inputs());
         else if(op.mode == "average")
-            prog->replace_instruction(ins, cpu_pooling<avg_pool>{op}, ins->arguments);
+            prog->replace_instruction(ins, cpu_pooling<avg_pool>{op}, ins->inputs());
     }
 };
...
@@ -16,11 +16,11 @@ void eliminate_workspace::apply(program& p) const
     std::vector<instruction_ref> allocs;
     for(auto ins : iterator_for(p))
     {
-        if(ins->output.size() != 1)
+        if(ins->outputs().size() != 1)
             continue;
-        if(ins->op.name() != "hip::allocate")
+        if(ins->name() != "hip::allocate")
             continue;
-        auto&& a = any_cast<hip_allocate>(ins->op);
+        auto&& a = any_cast<hip_allocate>(ins->get_operator());
         if(a.tag == "workspace")
         {
             n = std::max(n, ins->get_shape().bytes());
...
@@ -26,14 +26,14 @@ void fuse_ops::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {
-        if(ins->op.name() != "gpu::relu")
+        if(ins->name() != "gpu::relu")
             continue;
-        auto add_ins = ins->arguments.front();
+        auto add_ins = ins->inputs().front();
-        if(add_ins->op.name() != "gpu::add")
+        if(add_ins->name() != "gpu::add")
             continue;
-        auto args = add_ins->arguments;
+        auto args = add_ins->inputs();
         // Use the allocation from the relu operator
-        args.back() = ins->arguments.back();
+        args.back() = ins->inputs().back();
         p.replace_instruction(ins, hip_add_relu{}, args);
     }
 }
...
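For orientation: the loop above pattern-matches a gpu::relu whose producer is a gpu::add, reuses the relu's output allocation, and replaces the pair with a single fused hip_add_relu. Below is a schematic, self-contained toy of that match-and-rewrite idea; the ins struct and instruction names are invented for illustration and are not the migraph API.

#include <iostream>
#include <string>
#include <vector>

// A "program" is just a list of named instructions; input points at the
// producing instruction's index (-1 when there is none in this toy).
struct ins
{
    std::string name;
    int input;
};

int main()
{
    std::vector<ins> prog = {{"add", -1}, {"relu", 0}, {"mul", -1}};
    for(auto& i : prog)
    {
        if(i.name != "relu")
            continue;
        if(i.input < 0 || prog[i.input].name != "add")
            continue;
        prog[i.input].name = "add_relu"; // fuse the pair into the producer
        i.name = "identity";             // the relu now just forwards its input
    }
    for(const auto& i : prog)
        std::cout << i.name << '\n'; // prints: add_relu, identity, mul
}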
@@ -129,8 +129,16 @@ struct miopen_convolution
                                          workspace_size,
                                          false);
         algo = perf.fwd_algo;
-        return algo == miopenConvolutionFwdAlgoWinograd ? shape{shape::int8_type, {0}}
-                                                        : workspace_shape;
+        return shape{shape::int8_type, {perf.memory}};
+    }
+    friend std::ostream& operator<<(std::ostream& os, const miopen_convolution& self)
+    {
+        os << self.name() << "[";
+        os << self.op << ", ";
+        os << "algo=" << self.algo;
+        os << "]";
+        return os;
     }
 };
@@ -305,6 +313,34 @@ struct miopen_relu
     }
 };
+struct miopen_softmax
+{
+    softmax op;
+    std::string name() const { return "gpu::softmax"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(2).standard();
+        return op.compute_shape({inputs.at(0)});
+    }
+    argument
+    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
+    {
+        float alpha = 1, beta = 0;
+        auto x_desc = make_tensor(args[0].get_shape());
+        auto y_desc = make_tensor(output_shape);
+        miopenSoftmaxForward(ctx.handle.get(),
+                             &alpha,
+                             x_desc.get(),
+                             args[0].implicit(),
+                             &beta,
+                             y_desc.get(),
+                             args[1].implicit());
+        return args[1];
+    }
+};
 struct miopen_apply
 {
     program* prog = nullptr;
@@ -322,34 +358,38 @@ struct miopen_apply
         for(auto it = prog->begin(); it != prog->end(); it++)
         {
             auto s = it->get_shape();
-            if(it->op.name() == "convolution")
+            if(it->name() == "convolution")
             {
                 check_shape(s, apply_convolution(it));
             }
-            else if(it->op.name() == "activation")
+            else if(it->name() == "activation")
             {
                 check_shape(s, apply_activation(it));
             }
-            else if(it->op.name() == "pooling")
+            else if(it->name() == "pooling")
             {
                 check_shape(s, apply_pooling(it));
             }
-            else if(it->op.name() == "add")
+            else if(it->name() == "add")
             {
                 check_shape(s, apply_add(it));
             }
-            else if(it->op.name() == "gemm")
+            else if(it->name() == "gemm")
             {
                 check_shape(s, apply_gemm(it));
             }
-            else if(it->op.name() == "contiguous")
+            else if(it->name() == "contiguous")
             {
                 check_shape(s, apply_contiguous(it));
             }
-            else if(it->op.name() == "batch_norm_inference")
+            else if(it->name() == "batch_norm_inference")
             {
                 check_shape(s, apply_batch_norm_inference(it));
             }
+            else if(it->name() == "softmax")
+            {
+                check_shape(s, apply_softmax(it));
+            }
         }
     }
@@ -369,78 +409,85 @@ struct miopen_apply
     instruction_ref apply_convolution(instruction_ref ins)
     {
-        auto&& op = any_cast<convolution>(ins->op);
+        auto&& op = any_cast<convolution>(ins->get_operator());
         auto conv = miopen_convolution{op, make_conv(op)};
-        auto ws = conv.compile(ctx, ins->result, ins->arguments);
+        auto ws = conv.compile(ctx, ins->get_shape(), ins->inputs());
         auto workspace = insert_allocation(ins, ws, "workspace");
-        auto output = insert_allocation(ins, ins->result);
+        auto output = insert_allocation(ins, ins->get_shape());
         return prog->replace_instruction(
-            ins, conv, ins->arguments.at(0), ins->arguments.at(1), workspace, output);
+            ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output);
     }
     instruction_ref apply_pooling(instruction_ref ins)
     {
-        auto&& op = any_cast<pooling>(ins->op);
+        auto&& op = any_cast<pooling>(ins->get_operator());
         auto pd = make_pooling(op);
-        auto output = insert_allocation(ins, ins->result);
+        auto output = insert_allocation(ins, ins->get_shape());
         return prog->replace_instruction(
-            ins, miopen_pooling{op, std::move(pd)}, ins->arguments.at(0), output);
+            ins, miopen_pooling{op, std::move(pd)}, ins->inputs().at(0), output);
     }
     instruction_ref apply_activation(instruction_ref ins)
     {
-        auto&& op = any_cast<activation>(ins->op);
+        auto&& op = any_cast<activation>(ins->get_operator());
         auto ad = make_relu();
         if(op.mode == "relu")
         {
-            auto output = insert_allocation(ins, ins->result);
+            auto output = insert_allocation(ins, ins->get_shape());
             return prog->replace_instruction(
-                ins, miopen_relu{std::move(ad)}, ins->arguments.at(0), output);
+                ins, miopen_relu{std::move(ad)}, ins->inputs().at(0), output);
         }
         return ins;
     }
+    instruction_ref apply_softmax(instruction_ref ins)
+    {
+        auto&& op = any_cast<softmax>(ins->get_operator());
+        auto output = insert_allocation(ins, ins->get_shape());
+        return prog->replace_instruction(ins, miopen_softmax{op}, ins->inputs().at(0), output);
+    }
     instruction_ref apply_add(instruction_ref ins)
     {
-        auto output = insert_allocation(ins, ins->result);
+        auto output = insert_allocation(ins, ins->get_shape());
         return prog->replace_instruction(
-            ins, hip_add{}, ins->arguments.at(0), ins->arguments.at(1), output);
+            ins, hip_add{}, ins->inputs().at(0), ins->inputs().at(1), output);
     }
     instruction_ref apply_gemm(instruction_ref ins)
     {
-        auto&& op = any_cast<gemm>(ins->op);
+        auto&& op = any_cast<gemm>(ins->get_operator());
-        auto output = insert_allocation(ins, ins->result);
+        auto output = insert_allocation(ins, ins->get_shape());
         return prog->replace_instruction(
-            ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
+            ins, miopen_gemm{op}, ins->inputs().at(0), ins->inputs().at(1), output);
     }
     instruction_ref apply_contiguous(instruction_ref ins)
     {
-        auto&& op = any_cast<contiguous>(ins->op);
+        auto&& op = any_cast<contiguous>(ins->get_operator());
-        auto output = insert_allocation(ins, ins->result);
+        auto output = insert_allocation(ins, ins->get_shape());
-        return prog->replace_instruction(ins, miopen_contiguous{op}, ins->arguments.at(0), output);
+        return prog->replace_instruction(ins, miopen_contiguous{op}, ins->inputs().at(0), output);
     }
     instruction_ref apply_batch_norm_inference(instruction_ref ins)
     {
-        auto&& op = any_cast<batch_norm_inference>(ins->op);
+        auto&& op = any_cast<batch_norm_inference>(ins->get_operator());
-        auto output = insert_allocation(ins, ins->result);
+        auto output = insert_allocation(ins, ins->get_shape());
-        shape old_shape = ins->arguments.at(1)->get_shape();
+        shape old_shape = ins->inputs().at(1)->get_shape();
         std::vector<int64_t> new_shape{1, static_cast<int64_t>(old_shape.elements()), 1, 1};
         auto reshape_op = reshape{new_shape};
         std::vector<instruction_ref> reshapes;
-        std::transform(ins->arguments.begin() + 1,
-                       ins->arguments.end(),
+        std::transform(ins->inputs().begin() + 1,
+                       ins->inputs().end(),
                        std::back_inserter(reshapes),
                        [&](auto i) { return prog->insert_instruction(ins, reshape_op, i); });
         return prog->replace_instruction(ins,
                                          miopen_batch_norm_inference{op},
-                                         ins->arguments.at(0),
+                                         ins->inputs().at(0),
                                          reshapes[0],
                                          reshapes[1],
                                          reshapes[2],
...
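A note on the convention visible throughout this lowering code: insert_allocation creates the output buffer ahead of time and passes it to the lowered op as its last input (which is why miopen_softmax::compute_shape checks for two inputs), and compute writes into that buffer and returns it rather than allocating. The following self-contained toy sketches that destination-passing style using plain std::vector instead of migraph's argument type and a deliberately naive softmax; none of these names are migraph symbols.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// The caller owns the output buffer; the "op" only fills it and hands the
// same buffer back, mirroring how args[1] is written and returned above.
std::vector<float>& softmax_into(const std::vector<float>& x, std::vector<float>& out)
{
    float sum = 0;
    for(float v : x)
        sum += std::exp(v);
    for(std::size_t i = 0; i < x.size(); i++)
        out[i] = std::exp(x[i]) / sum;
    return out;
}

int main()
{
    std::vector<float> x = {1.0f, 2.0f, 3.0f};
    std::vector<float> y(x.size()); // preallocated "output argument"
    softmax_into(x, y);             // writes into y and returns it
    for(float v : y)
        std::cout << v << ' ';
    std::cout << '\n';
}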
@@ -30,11 +30,11 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
         lowering{ctx},
         fuse_ops{},
         dead_code_elimination{},
-        eliminate_workspace{},
         eliminate_contiguous{},
         dead_code_elimination{},
         write_literals{&ctx},
-        eliminate_allocation{""},
+        eliminate_workspace{},
+        eliminate_allocation{"hip::allocate"},
         check_context<context>{},
         dead_code_elimination{}
     };
...
@@ -28,9 +28,9 @@ void write_literals::apply(program& p) const
     assert(ctx != nullptr);
     for(auto ins : iterator_for(p))
     {
-        if(ins->op.name() == "@literal")
+        if(ins->name() == "@literal")
         {
-            argument a = to_gpu(ins->lit.get_argument());
+            argument a = to_gpu(ins->get_literal().get_argument());
             std::size_t n = ctx->literals.size();
             ctx->literals.push_back(a);
             p.replace_instruction(ins, hip_load_literal{a.get_shape(), n});
...
@@ -21,13 +21,13 @@ struct reverse_pass
 {
     for(auto ins : migraph::iterator_for(p))
     {
-        if(ins->op.name() == "sum")
+        if(ins->name() == "sum")
         {
-            p.replace_instruction(ins, minus_op{}, ins->arguments);
+            p.replace_instruction(ins, minus_op{}, ins->inputs());
         }
-        else if(ins->op.name() == "minus")
+        else if(ins->name() == "minus")
         {
-            p.replace_instruction(ins, sum_op{}, ins->arguments);
+            p.replace_instruction(ins, sum_op{}, ins->inputs());
         }
     }
 }
...
@@ -97,10 +97,10 @@ void compile_check(migraph::program& p, const migraph::target& t)
 }
 template <class V>
-migraph::argument run_cpu()
+migraph::argument run_cpu(migraph::program& p)
 {
     V v;
-    auto p = v.create_program();
+    p = v.create_program();
     auto_print pp{p, 0};
     compile_check(p, migraph::cpu::cpu_target{});
     migraph::program::parameter_map m;
@@ -112,10 +112,10 @@ migraph::argument run_cpu()
 }
 template <class V>
-migraph::argument run_gpu()
+migraph::argument run_gpu(migraph::program& p)
 {
     V v;
-    auto p = v.create_program();
+    p = v.create_program();
     auto_print pp{p, 1};
     compile_check(p, migraph::gpu::target{});
     migraph::program::parameter_map m;
@@ -131,9 +131,20 @@ template <class V>
 void verify_program()
 {
     auto_print::set_terminate_handler(migraph::get_type_name<V>());
-    auto cpu_arg_f = detach_async([] { return run_cpu<V>(); });
-    auto gpu_arg = run_gpu<V>();
-    verify_args(migraph::get_type_name<V>(), cpu_arg_f.get(), gpu_arg);
+    migraph::program cpu_prog;
+    migraph::program gpu_prog;
+    auto cpu_arg_f = detach_async([&] { return run_cpu<V>(cpu_prog); });
+    auto gpu_arg   = run_gpu<V>(gpu_prog);
+    bool passed    = verify_args(migraph::get_type_name<V>(), cpu_arg_f.get(), gpu_arg);
+    if(not passed)
+    {
+        V v;
+        auto p = v.create_program();
+        std::cout << p << std::endl;
+        std::cout << "cpu:\n" << cpu_prog << std::endl;
+        std::cout << "gpu:\n" << gpu_prog << std::endl;
+        std::cout << std::endl;
+    }
     std::set_terminate(nullptr);
 }
@@ -235,6 +246,28 @@ struct test_add_broadcast5
     }
 };
+struct test_softmax
+{
+    migraph::program create_program() const
+    {
+        migraph::program p;
+        auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {5, 3, 4, 2}});
+        p.add_instruction(migraph::softmax{}, x);
+        return p;
+    }
+};
+struct test_softmax2
+{
+    migraph::program create_program() const
+    {
+        migraph::program p;
+        auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 1000, 1, 1}});
+        p.add_instruction(migraph::softmax{}, x);
+        return p;
+    }
+};
 struct test_conv
 {
     migraph::program create_program() const
@@ -248,6 +281,20 @@ struct test_conv
     }
 };
+struct test_conv2
+{
+    migraph::program create_program() const
+    {
+        migraph::program p;
+        auto input =
+            p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 512, 28, 28}});
+        auto weights =
+            p.add_parameter("w", migraph::shape{migraph::shape::float_type, {256, 512, 1, 1}});
+        p.add_instruction(migraph::convolution{{0, 0}, {1, 1}, {1, 1}}, input, weights);
+        return p;
+    }
+};
 struct test_conv_relu
 {
     migraph::program create_program() const
@@ -428,6 +475,27 @@ struct test_batchnorm_inference
     }
 };
+struct test_conv_bn
+{
+    migraph::program create_program() const
+    {
+        migraph::program p;
+        migraph::shape xs{migraph::shape::float_type, {1, 3, 224, 224}};
+        migraph::shape ws{migraph::shape::float_type, {64, 3, 7, 7}};
+        migraph::shape vars{migraph::shape::float_type, {64}};
+        auto x        = p.add_parameter("x", xs);
+        auto w        = p.add_parameter("w", ws);
+        auto conv     = p.add_instruction(migraph::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w);
+        auto scale    = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
+        auto bias     = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
+        auto mean     = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
+        auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4)));
+        p.add_instruction(migraph::batch_norm_inference{}, conv, scale, bias, mean, variance);
+        return p;
+    }
+};
 struct test_conv_bn_relu_pooling
 {
     migraph::program create_program() const
@@ -495,7 +563,10 @@ int main()
     verify_program<test_add_broadcast3>();
     verify_program<test_add_broadcast4>();
     verify_program<test_add_broadcast5>();
+    verify_program<test_softmax>();
+    verify_program<test_softmax2>();
     verify_program<test_conv>();
+    verify_program<test_conv2>();
     verify_program<test_conv_relu>();
     verify_program<test_add_relu>();
     verify_program<test_conv_pooling>();
@@ -508,6 +579,7 @@ int main()
     verify_program<test_transpose>();
     verify_program<test_batchnorm_inference>();
     verify_program<test_batchnorm_inference_2>();
+    verify_program<test_conv_bn>();
     verify_program<test_conv_bn_relu_pooling>();
     verify_program<test_conv_bn_relu_pooling2>();
 }
#ifndef MIGRAPH_GUARD_ROB_HPP
#define MIGRAPH_GUARD_ROB_HPP

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif

// Used to access private member variables
template <class Tag>
struct stowed
{
    static typename Tag::type value;
};
template <class Tag>
typename Tag::type stowed<Tag>::value;

template <class Tag, typename Tag::type X>
struct stow_private
{
    stow_private() noexcept { stowed<Tag>::value = X; }
    static stow_private instance;
};
template <class Tag, typename Tag::type X>
stow_private<Tag, X> stow_private<Tag, X>::instance;

template <class C, class T>
struct mem_data_ptr
{
    using type = T C::*;
};

#define MIGRAPH_ROB(name, Type, C, mem)                 \
    struct name##_tag : mem_data_ptr<C, Type>           \
    {                                                   \
    };                                                  \
    template struct stow_private<name##_tag, &C::mem>;  \
    template <class T>                                  \
    auto& name(T&& x)                                   \
    {                                                   \
        return x.*stowed<name##_tag>::value;            \
    }

#ifdef __clang__
#pragma clang diagnostic pop
#endif

#endif
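For orientation, here is a minimal sketch of how this macro is meant to be used, assuming the header above is available as rob.hpp; the widget type and get_secret name are made up for illustration (the real usage appears in the test diff below). The explicit instantiation generated by the macro is what legally bypasses access control: taking &C::mem as a template argument is allowed inside an explicit instantiation even when mem is private.

#include "rob.hpp"
#include <iostream>

struct widget
{
    int value() const { return secret; }

    private:
    int secret = 42;
};

// Generates a free function get_secret(w) that returns a reference to the
// private member widget::secret.
MIGRAPH_ROB(get_secret, int, widget, secret)

int main()
{
    widget w;
    get_secret(w) = 7;              // write the private member from outside
    std::cout << w.value() << "\n"; // prints 7
}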
@@ -2,6 +2,7 @@
 #include <migraph/instruction.hpp>
 #include <basic_ops.hpp>
 #include <test.hpp>
+#include <rob.hpp>
 void simple_test()
 {
@@ -38,6 +39,11 @@ void incomplete_args()
     EXPECT(bool{p.validate() == ins});
 }
+MIGRAPH_ROB(access_ins_arguments,
+            std::vector<migraph::instruction_ref>,
+            migraph::instruction,
+            arguments)
 void invalid_args()
 {
     migraph::program p;
@@ -45,7 +51,7 @@ void invalid_args()
     auto one = p.add_literal(1);
     auto two = p.add_literal(2);
     auto ins = p.add_instruction(sum_op{}, one, two);
-    ins->arguments.clear();
+    access_ins_arguments(*ins).clear();
     EXPECT(bool{p.validate() == p.begin()});
 }
...
@@ -8,6 +8,7 @@
 #include <type_traits>
 #include <utility>
 #include <migraph/shape.hpp>
+#include <migraph/rank.hpp>
 #include <migraph/argument.hpp>
 #include <migraph/context.hpp>
 #include <migraph/auto_any_cast.hpp>
@@ -27,13 +28,16 @@ struct operation
     /// exception.
     shape compute_shape(const std::vector<shape>& input) const;
     /**
-     * @brief This performs the operation's computation
+     * @brief This performs the operation's computation.
+     *
+     * This method can be optional when the operation is only used as a placeholder to be lowered
+     * later on.
      *
     * @param ctx This is the context created by the `target` during compilation. Implementations
     * can use the target's `context` class rather than the `context` interface class.
     * @param output This is the output shape. It is equivalent to running `compute_shape` with each
     * `shape` of the `argument`.
-     * @param input This is the `argument` result from the previous instuction's computation.
+     * @param input This is the `argument` result from the previous instruction's computation.
     * @return Return an `argument` of the result computation. The `shape` of `argument` should be
     * the same the `output` shape.
     */
@@ -55,11 +59,29 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
 } // namespace operation_stream
+template <class T>
+auto compute_op(rank<1>,
+                const T& x,
+                context& ctx,
+                const shape& output_shape,
+                const std::vector<argument>& input)
+    -> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
+{
+    return x.compute(auto_any_cast(ctx), output_shape, input);
+}
+template <class T>
+argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
+{
+    std::string name = x.name();
+    MIGRAPH_THROW("Not computable: " + name);
+}
 template <class T>
 argument
 compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
 {
-    return x.compute(auto_any_cast(ctx), output_shape, input);
+    return compute_op(rank<1>{}, x, ctx, output_shape, input);
 }
 <%
...
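The compute_op overloads added in the last hunk use rank-based tag dispatch: the rank<1> overload is preferred whenever x.compute(...) is well-formed, and the rank<0> fallback throws for placeholder operations that have no compute method. Below is a self-contained sketch of the same idiom; rank is redefined locally so the snippet stands alone (migraph/rank.hpp provides the real one), and invoke, has_run, and no_run are hypothetical names used only for illustration.

#include <iostream>

// rank<1> derives from rank<0>, so passing rank<1>{} prefers the more
// specialized overload and falls back to rank<0> when SFINAE rejects it.
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};

// Preferred overload: participates only when x.run() is a valid expression.
template <class T>
auto invoke(rank<1>, const T& x) -> decltype(x.run())
{
    return x.run();
}

// Fallback overload: chosen for types without run(), mirroring the
// "Not computable" error path in compute_op.
template <class T>
int invoke(rank<0>, const T&)
{
    std::cout << "no run() available\n";
    return -1;
}

struct has_run
{
    int run() const { return 42; }
};
struct no_run
{
};

int main()
{
    std::cout << invoke(rank<1>{}, has_run{}) << "\n"; // prints 42
    invoke(rank<1>{}, no_run{});                       // prints the fallback message
}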