Commit 09ed8ee2 authored by Khalique's avatar Khalique
Browse files

Merge branch 'gemm_beta' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into group_conv

parents 637aa811 c33b8a63
...@@ -74,7 +74,7 @@ ...@@ -74,7 +74,7 @@
</message> </message>
</rule> </rule>
<rule> <rule>
<pattern>(fclose|free|hipFree|hipHostFree|hipFreeArray|hipMemFree|hipStreamDestroy|hipEventDestroy|hipArrayDestroy|hipCtxDestroy|hipDestroyTextureObject|hipDestroySurfaceObject) \(</pattern> <pattern>\\W(fclose|free|hipFree|hipHostFree|hipFreeArray|hipMemFree|hipStreamDestroy|hipEventDestroy|hipArrayDestroy|hipCtxDestroy|hipDestroyTextureObject|hipDestroySurfaceObject) \(</pattern>
<message> <message>
<id>useManagePointer</id> <id>useManagePointer</id>
<severity>style</severity> <severity>style</severity>
......
...@@ -71,6 +71,8 @@ struct instruction ...@@ -71,6 +71,8 @@ struct instruction
static void static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args); replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
argument eval() const;
static instruction_ref get_output_alias(instruction_ref ins); static instruction_ref get_output_alias(instruction_ref ins);
private: private:
......
...@@ -53,6 +53,9 @@ struct operation ...@@ -53,6 +53,9 @@ struct operation
friend std::ostream& operator<<(std::ostream& os, const operation& op); friend std::ostream& operator<<(std::ostream& os, const operation& op);
}; };
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
#else #else
namespace operation_stream { namespace operation_stream {
...@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_equal } // namespace operation_equal
template <class T> template <class T>
auto compute_op(rank<1>, auto compute_op(rank<2>,
const T& x, const T& x,
context& ctx, context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -99,6 +102,14 @@ auto compute_op(rank<1>, ...@@ -99,6 +102,14 @@ auto compute_op(rank<1>,
return x.compute(auto_any_cast(ctx), output_shape, input); return x.compute(auto_any_cast(ctx), output_shape, input);
} }
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T> template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&) argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{ {
...@@ -110,7 +121,53 @@ template <class T> ...@@ -110,7 +121,53 @@ template <class T>
argument argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input) compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return compute_op(rank<1>{}, x, ctx, output_shape, input); return compute_op(rank<2>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, output_shape, input);
}
template <class T>
auto is_context_free_op(rank<1>,
const T& x,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input), std::true_type{});
template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
-> std::false_type;
template <class T>
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
rank<1>{}, x, std::declval<const shape&>(), std::declval<std::vector<argument>>()))
{
return {};
} }
template <class T> template <class T>
...@@ -138,9 +195,11 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -138,9 +195,11 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
* struct operation * struct operation
* { * {
* std::string name() const; * std::string name() const;
* bool is_context_free() const;
* int output_alias(const std::vector<shape>& input) const; * int output_alias(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const; * shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const; * argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ; * friend bool operator==(const operation & x,const operation & y) ;
* }; * };
...@@ -210,6 +269,12 @@ struct operation ...@@ -210,6 +269,12 @@ struct operation
return (*this).private_detail_te_get_handle().name(); return (*this).private_detail_te_get_handle().name();
} }
bool is_context_free() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_context_free();
}
int output_alias(const std::vector<shape>& input) const int output_alias(const std::vector<shape>& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -228,6 +293,12 @@ struct operation ...@@ -228,6 +293,12 @@ struct operation
return (*this).private_detail_te_get_handle().compute(ctx, output, input); return (*this).private_detail_te_get_handle().compute(ctx, output, input);
} }
argument compute(const shape& output, const std::vector<argument>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(output, input);
}
friend std::ostream& operator<<(std::ostream& os, const operation& op) friend std::ostream& operator<<(std::ostream& os, const operation& op)
{ {
assert(op.private_detail_te_handle_mem_var); assert(op.private_detail_te_handle_mem_var);
...@@ -248,10 +319,12 @@ struct operation ...@@ -248,10 +319,12 @@ struct operation
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual bool is_context_free() const = 0;
virtual int output_alias(const std::vector<shape>& input) const = 0; virtual int output_alias(const std::vector<shape>& input) const = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0; virtual shape compute_shape(const std::vector<shape>& input) const = 0;
virtual argument virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0; compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0; virtual bool operator==(const operation& y) const = 0;
}; };
...@@ -286,6 +359,12 @@ struct operation ...@@ -286,6 +359,12 @@ struct operation
std::string name() const override { return private_detail_te_value.name(); } std::string name() const override { return private_detail_te_value.name(); }
bool is_context_free() const override
{
return is_context_free_op(private_detail_te_value);
}
int output_alias(const std::vector<shape>& input) const override int output_alias(const std::vector<shape>& input) const override
{ {
...@@ -306,6 +385,12 @@ struct operation ...@@ -306,6 +385,12 @@ struct operation
return compute_op(private_detail_te_value, ctx, output, input); return compute_op(private_detail_te_value, ctx, output, input);
} }
argument compute(const shape& output, const std::vector<argument>& input) const override
{
return compute_op(private_detail_te_value, output, input);
}
std::ostream& operator_shift_left(std::ostream& os) const override std::ostream& operator_shift_left(std::ostream& os) const override
{ {
using migraphx::operation_stream::operator<<; using migraphx::operation_stream::operator<<;
...@@ -385,6 +470,14 @@ inline const ValueType& any_cast(const operation& x) ...@@ -385,6 +470,14 @@ inline const ValueType& any_cast(const operation& x)
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); } inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
template <class T>
bool is_context_free(const T& x)
{
return is_context_free_op(x);
}
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ namespace op { ...@@ -16,7 +16,7 @@ namespace op {
struct not_computable struct not_computable
{ {
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(const shape&, const std::vector<argument>&) const
{ {
MIGRAPHX_THROW("not computable"); MIGRAPHX_THROW("not computable");
} }
...@@ -298,7 +298,7 @@ struct transpose ...@@ -298,7 +298,7 @@ struct transpose
} }
return {t, output_lens, output_strides}; return {t, output_lens, output_strides};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
...@@ -372,6 +372,27 @@ struct concat ...@@ -372,6 +372,27 @@ struct concat
new_lens[axis] = new_dim_axis; new_lens[axis] = new_dim_axis;
return {type, new_lens}; return {type, new_lens};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
for(std::size_t l = 0; l < args.size(); l++)
{
auto argl = args[l];
std::size_t nelements = argl.get_shape().elements();
visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape =
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]);
// cppcheck-suppress useStlAlgorithm
for(std::size_t i = 0; i < nelements; i++)
{
slice[i] = input[i];
}
});
}
return result;
}
int output_alias(const std::vector<shape>&) const { return 0; } int output_alias(const std::vector<shape>&) const { return 0; }
}; };
...@@ -439,7 +460,7 @@ struct slice ...@@ -439,7 +460,7 @@ struct slice
} }
return shape{t, new_lens, old_strides}; return shape{t, new_lens, old_strides};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
auto input = args[0]; auto input = args[0];
auto offset = compute_offset(input.get_shape()) * output_shape.type_size(); auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
...@@ -489,7 +510,7 @@ struct squeeze ...@@ -489,7 +510,7 @@ struct squeeze
} }
return shape{type, new_lens}; return shape{type, new_lens};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
...@@ -528,7 +549,7 @@ struct unsqueeze ...@@ -528,7 +549,7 @@ struct unsqueeze
} }
return shape{type, new_lens}; return shape{type, new_lens};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
...@@ -580,7 +601,7 @@ struct reshape ...@@ -580,7 +601,7 @@ struct reshape
MIGRAPHX_THROW("Wrong number of elements for reshape"); MIGRAPHX_THROW("Wrong number of elements for reshape");
return s; return s;
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
...@@ -626,7 +647,7 @@ struct identity ...@@ -626,7 +647,7 @@ struct identity
{ {
std::string name() const { return "identity"; } std::string name() const { return "identity"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); } shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
...@@ -744,7 +765,7 @@ struct flatten ...@@ -744,7 +765,7 @@ struct flatten
std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{}); std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {inputs.at(0).type(), {x, y}}; return {inputs.at(0).type(), {x, y}};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
...@@ -796,7 +817,7 @@ struct broadcast ...@@ -796,7 +817,7 @@ struct broadcast
return {t, broadcast_shape.lens(), std::move(bcast_strides)}; return {t, broadcast_shape.lens(), std::move(bcast_strides)};
} }
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
...@@ -838,7 +859,7 @@ struct multibroadcast ...@@ -838,7 +859,7 @@ struct multibroadcast
} }
return {t, output_lens, bcast_strides}; return {t, output_lens, bcast_strides};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
...@@ -860,7 +881,7 @@ struct scalar ...@@ -860,7 +881,7 @@ struct scalar
return {t, scalar_bcast.lens(), strides}; return {t, scalar_bcast.lens(), strides};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
...@@ -925,7 +946,7 @@ struct load ...@@ -925,7 +946,7 @@ struct load
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return s; return s;
} }
argument compute(context&, const shape&, const std::vector<argument>& args) const argument compute(const shape&, const std::vector<argument>& args) const
{ {
return {s, args[0].data() + offset}; return {s, args[0].data() + offset};
} }
...@@ -948,10 +969,7 @@ struct outline ...@@ -948,10 +969,7 @@ struct outline
check_shapes{inputs, *this}.has(0); check_shapes{inputs, *this}.has(0);
return s; return s;
} }
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
{
return {s, nullptr};
}
}; };
} // namespace op } // namespace op
......
...@@ -170,6 +170,27 @@ std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args) ...@@ -170,6 +170,27 @@ std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
return shapes; return shapes;
} }
argument instruction::eval() const
{
if(op.name() == "@literal")
{
return this->get_literal().get_argument();
}
if(is_context_free(op))
{
std::vector<argument> args;
for(auto&& arg : this->inputs())
{
argument a = arg->eval();
if(a.empty())
return {};
args.push_back(a);
}
return op.compute(result, args);
}
return {};
}
instruction_ref instruction::get_output_alias(instruction_ref ins) instruction_ref instruction::get_output_alias(instruction_ref ins)
{ {
auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs())); auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs()));
......
...@@ -150,7 +150,7 @@ struct onnx_parser ...@@ -150,7 +150,7 @@ struct onnx_parser
if(s0->size() > s1->size()) if(s0->size() > s1->size())
std::swap(s0, s1); std::swap(s0, s1);
std::vector<std::size_t> output_lens(s1->size()); std::vector<std::size_t> output_lens(*s1);
auto offset = s1->size() - s0->size(); auto offset = s1->size() - s0->size();
std::transform(s0->begin(), std::transform(s0->begin(),
s0->end(), s0->end(),
...@@ -388,7 +388,7 @@ struct onnx_parser ...@@ -388,7 +388,7 @@ struct onnx_parser
parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 1.0f;
bool transa = false; bool transa = false;
bool transb = false; bool transb = false;
if(contains(attributes, "alpha")) if(contains(attributes, "alpha"))
...@@ -412,10 +412,20 @@ struct onnx_parser ...@@ -412,10 +412,20 @@ struct onnx_parser
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1]; auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
if(args.size() == 3) if(args.size() == 3)
{ {
uint64_t axis = 1; if(beta != 0.f)
auto l3 = prog.add_instruction(op::dot{alpha, beta}, l1, l2); {
auto l4 = prog.add_instruction(op::broadcast{axis, l3->get_shape()}, args[2]); auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
return prog.add_instruction(op::add{}, l3, l4); auto l4 = args[2];
if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B)
return l3;
if(beta != 1.f)
{
auto beta_val = prog.add_literal(beta);
auto l5 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val);
l4 = prog.add_instruction(op::mul{}, args[2], l5);
}
return add_broadcastable_binary_op(l3, l4, op::add{});
}
} }
return prog.add_instruction(op::dot{alpha, beta}, l1, l2); return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
} }
......
...@@ -304,24 +304,7 @@ struct cpu_concat ...@@ -304,24 +304,7 @@ struct cpu_concat
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; return op.compute(output_shape, std::move(args));
std::vector<std::size_t> coffsets = op.compute_offsets(output_shape, args);
for(std::size_t l = 0; l < args.size(); l++)
{
auto argl = args[l];
std::size_t nelements = argl.get_shape().elements();
visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape =
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]);
// cppcheck-suppress useStlAlgorithm
for(std::size_t i = 0; i < nelements; i++)
{
slice[i] = input[i];
}
});
}
return result;
} }
}; };
......
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
struct sum_cf_op
{
std::string name() const { return "sum_cf"; }
migraphx::argument compute(const migraphx::shape&, std::vector<migraphx::argument> args) const
{
migraphx::argument result;
if(args.size() != 2)
MIGRAPHX_THROW("Wrong args");
if(args[0].get_shape() != args[1].get_shape())
MIGRAPHX_THROW("Wrong args");
if(args[0].get_shape().lens().size() != 1)
MIGRAPHX_THROW("Wrong args");
if(args[0].get_shape().lens().front() != 1)
MIGRAPHX_THROW("Wrong args");
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) { result = migraphx::literal{x + y}.get_argument(); });
});
return result;
}
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
if(inputs.size() != 2)
MIGRAPHX_THROW("Wrong inputs");
return inputs.front();
}
};
struct non_computable_cf
{
std::string name() const { return "non_computable"; }
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
if(inputs.empty())
return {};
return inputs.front();
}
};
struct test_context
{
void finish() const {}
};
TEST_CASE(literal_test)
{
migraphx::program p;
auto lit = p.add_literal(1);
CHECK(lit->eval() == migraphx::literal{1});
}
TEST_CASE(param_test)
{
migraphx::program p;
auto lit = p.add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}});
CHECK(lit->eval().empty());
}
TEST_CASE(op_test1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_cf_op{}, one, two);
CHECK(sum->eval() == migraphx::literal{3});
}
TEST_CASE(op_test2)
{
migraphx::program p;
auto x = p.add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}});
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_cf_op{}, x, two);
CHECK(sum->eval().empty());
}
TEST_CASE(op_test3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_cf_op{}, sum1, two);
CHECK(sum2->eval().empty());
}
TEST_CASE(compute_op_c)
{
migraphx::operation op = sum_op{};
auto one = migraphx::literal{1}.get_argument();
auto two = migraphx::literal{2}.get_argument();
EXPECT(test::throws([&] {
op.compute(migraphx::shape{migraphx::shape::float_type, {1}}, {one, two});
}));
}
TEST_CASE(compute_nop_c)
{
migraphx::operation op = non_computable_cf{};
auto one = migraphx::literal{1}.get_argument();
auto two = migraphx::literal{2}.get_argument();
EXPECT(test::throws([&] {
op.compute(migraphx::shape{migraphx::shape::float_type, {1}}, {one, two});
}));
}
TEST_CASE(compute_nop_context)
{
migraphx::operation op = non_computable_cf{};
auto one = migraphx::literal{1}.get_argument();
auto two = migraphx::literal{2}.get_argument();
migraphx::context ctx = test_context{};
EXPECT(test::throws([&] {
op.compute(ctx, migraphx::shape{migraphx::shape::float_type, {1}}, {one, two});
}));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -351,8 +351,8 @@ TEST_CASE(implicit_bcast_test) ...@@ -351,8 +351,8 @@ TEST_CASE(implicit_bcast_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 4, 5}}, l0); auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3); p.add_instruction(migraphx::op::add{}, l2, l3);
auto prog = migraphx::parse_onnx("implicit_bcast_test.onnx"); auto prog = migraphx::parse_onnx("implicit_bcast_test.onnx");
...@@ -460,12 +460,11 @@ TEST_CASE(gemm_test) ...@@ -460,12 +460,11 @@ TEST_CASE(gemm_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {}}); p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {}});
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0); auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto d0 = p.add_instruction(migraphx::op::dot{2, 2}, t0, t1); auto alpha = 2.f;
auto b0 = p.add_instruction(migraphx::op::broadcast{1, d0->get_shape()}, l2); p.add_instruction(migraphx::op::dot{alpha}, t0, t1);
p.add_instruction(migraphx::op::add{}, d0, b0);
auto prog = migraphx::parse_onnx("gemm_test.onnx"); auto prog = migraphx::parse_onnx("gemm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -477,8 +476,8 @@ TEST_CASE(add_scalar_test) ...@@ -477,8 +476,8 @@ TEST_CASE(add_scalar_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 0, 5}}, l0); auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 0, 5}}, l1); auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1); p.add_instruction(migraphx::op::add{}, m0, m1);
auto prog = migraphx::parse_onnx("add_scalar_test.onnx"); auto prog = migraphx::parse_onnx("add_scalar_test.onnx");
......
...@@ -53,6 +53,9 @@ struct operation ...@@ -53,6 +53,9 @@ struct operation
friend std::ostream& operator<<(std::ostream& os, const operation& op); friend std::ostream& operator<<(std::ostream& os, const operation& op);
}; };
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
#else #else
namespace operation_stream { namespace operation_stream {
...@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_equal } // namespace operation_equal
template <class T> template <class T>
auto compute_op(rank<1>, auto compute_op(rank<2>,
const T& x, const T& x,
context& ctx, context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -99,6 +102,14 @@ auto compute_op(rank<1>, ...@@ -99,6 +102,14 @@ auto compute_op(rank<1>,
return x.compute(auto_any_cast(ctx), output_shape, input); return x.compute(auto_any_cast(ctx), output_shape, input);
} }
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T> template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&) argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{ {
...@@ -110,7 +121,53 @@ template <class T> ...@@ -110,7 +121,53 @@ template <class T>
argument argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input) compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return compute_op(rank<1>{}, x, ctx, output_shape, input); return compute_op(rank<2>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, output_shape, input);
}
template <class T>
auto is_context_free_op(rank<1>,
const T& x,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input), std::true_type{});
template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
-> std::false_type;
template <class T>
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
rank<1>{}, x, std::declval<const shape&>(), std::declval<std::vector<argument>>()))
{
return {};
} }
template <class T> template <class T>
...@@ -136,6 +193,7 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -136,6 +193,7 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
interface( interface(
'operation', 'operation',
virtual('name', returns = 'std::string', const = True), virtual('name', returns = 'std::string', const = True),
virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'),
virtual('output_alias', virtual('output_alias',
returns = 'int', returns = 'int',
input = 'const std::vector<shape>&', input = 'const std::vector<shape>&',
...@@ -149,6 +207,12 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -149,6 +207,12 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
input = 'const std::vector<argument>&', input = 'const std::vector<argument>&',
const = True, const = True,
default = 'compute_op'), default = 'compute_op'),
virtual('compute',
returns = 'argument',
output = 'const shape&',
input = 'const std::vector<argument>&',
const = True,
default = 'compute_op'),
friend('operator<<', friend('operator<<',
returns = 'std::ostream &', returns = 'std::ostream &',
os = 'std::ostream &', os = 'std::ostream &',
...@@ -165,6 +229,14 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -165,6 +229,14 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return !(x == y); return !(x == y);
} }
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
template <class T>
bool is_context_free(const T& x)
{
return is_context_free_op(x);
}
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment