Commit c33b8a63 authored by Khalique's avatar Khalique
Browse files

Merge branch 'gemm_beta' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into gemm_beta

parents f0f5c4e8 37c8b12a
...@@ -74,7 +74,7 @@ ...@@ -74,7 +74,7 @@
</message> </message>
</rule> </rule>
<rule> <rule>
<pattern>(fclose|free|hipFree|hipHostFree|hipFreeArray|hipMemFree|hipStreamDestroy|hipEventDestroy|hipArrayDestroy|hipCtxDestroy|hipDestroyTextureObject|hipDestroySurfaceObject) \(</pattern> <pattern>\\W(fclose|free|hipFree|hipHostFree|hipFreeArray|hipMemFree|hipStreamDestroy|hipEventDestroy|hipArrayDestroy|hipCtxDestroy|hipDestroyTextureObject|hipDestroySurfaceObject) \(</pattern>
<message> <message>
<id>useManagePointer</id> <id>useManagePointer</id>
<severity>style</severity> <severity>style</severity>
......
...@@ -71,6 +71,8 @@ struct instruction ...@@ -71,6 +71,8 @@ struct instruction
static void static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args); replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
argument eval() const;
static instruction_ref get_output_alias(instruction_ref ins); static instruction_ref get_output_alias(instruction_ref ins);
private: private:
......
...@@ -53,6 +53,9 @@ struct operation ...@@ -53,6 +53,9 @@ struct operation
friend std::ostream& operator<<(std::ostream& os, const operation& op); friend std::ostream& operator<<(std::ostream& os, const operation& op);
}; };
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
#else #else
namespace operation_stream { namespace operation_stream {
...@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_equal } // namespace operation_equal
template <class T> template <class T>
auto compute_op(rank<1>, auto compute_op(rank<2>,
const T& x, const T& x,
context& ctx, context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -99,6 +102,14 @@ auto compute_op(rank<1>, ...@@ -99,6 +102,14 @@ auto compute_op(rank<1>,
return x.compute(auto_any_cast(ctx), output_shape, input); return x.compute(auto_any_cast(ctx), output_shape, input);
} }
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T> template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&) argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{ {
...@@ -110,7 +121,53 @@ template <class T> ...@@ -110,7 +121,53 @@ template <class T>
argument argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input) compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return compute_op(rank<1>{}, x, ctx, output_shape, input); return compute_op(rank<2>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, output_shape, input);
}
template <class T>
auto is_context_free_op(rank<1>,
const T& x,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input), std::true_type{});
template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
-> std::false_type;
template <class T>
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
rank<1>{}, x, std::declval<const shape&>(), std::declval<std::vector<argument>>()))
{
return {};
} }
template <class T> template <class T>
...@@ -138,9 +195,11 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -138,9 +195,11 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
* struct operation * struct operation
* { * {
* std::string name() const; * std::string name() const;
* bool is_context_free() const;
* int output_alias(const std::vector<shape>& input) const; * int output_alias(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const; * shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const; * argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ; * friend bool operator==(const operation & x,const operation & y) ;
* }; * };
...@@ -210,6 +269,12 @@ struct operation ...@@ -210,6 +269,12 @@ struct operation
return (*this).private_detail_te_get_handle().name(); return (*this).private_detail_te_get_handle().name();
} }
bool is_context_free() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_context_free();
}
int output_alias(const std::vector<shape>& input) const int output_alias(const std::vector<shape>& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -228,6 +293,12 @@ struct operation ...@@ -228,6 +293,12 @@ struct operation
return (*this).private_detail_te_get_handle().compute(ctx, output, input); return (*this).private_detail_te_get_handle().compute(ctx, output, input);
} }
argument compute(const shape& output, const std::vector<argument>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(output, input);
}
friend std::ostream& operator<<(std::ostream& os, const operation& op) friend std::ostream& operator<<(std::ostream& os, const operation& op)
{ {
assert(op.private_detail_te_handle_mem_var); assert(op.private_detail_te_handle_mem_var);
...@@ -248,12 +319,14 @@ struct operation ...@@ -248,12 +319,14 @@ struct operation
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual bool is_context_free() const = 0;
virtual int output_alias(const std::vector<shape>& input) const = 0; virtual int output_alias(const std::vector<shape>& input) const = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0; virtual shape compute_shape(const std::vector<shape>& input) const = 0;
virtual argument virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0; compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
virtual bool operator==(const operation& y) const = 0; virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -286,6 +359,12 @@ struct operation ...@@ -286,6 +359,12 @@ struct operation
std::string name() const override { return private_detail_te_value.name(); } std::string name() const override { return private_detail_te_value.name(); }
bool is_context_free() const override
{
return is_context_free_op(private_detail_te_value);
}
int output_alias(const std::vector<shape>& input) const override int output_alias(const std::vector<shape>& input) const override
{ {
...@@ -306,6 +385,12 @@ struct operation ...@@ -306,6 +385,12 @@ struct operation
return compute_op(private_detail_te_value, ctx, output, input); return compute_op(private_detail_te_value, ctx, output, input);
} }
argument compute(const shape& output, const std::vector<argument>& input) const override
{
return compute_op(private_detail_te_value, output, input);
}
std::ostream& operator_shift_left(std::ostream& os) const override std::ostream& operator_shift_left(std::ostream& os) const override
{ {
using migraphx::operation_stream::operator<<; using migraphx::operation_stream::operator<<;
...@@ -385,6 +470,14 @@ inline const ValueType& any_cast(const operation& x) ...@@ -385,6 +470,14 @@ inline const ValueType& any_cast(const operation& x)
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); } inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
template <class T>
bool is_context_free(const T& x)
{
return is_context_free_op(x);
}
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ namespace op { ...@@ -16,7 +16,7 @@ namespace op {
struct not_computable struct not_computable
{ {
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(const shape&, const std::vector<argument>&) const
{ {
MIGRAPHX_THROW("not computable"); MIGRAPHX_THROW("not computable");
} }
...@@ -296,7 +296,7 @@ struct transpose ...@@ -296,7 +296,7 @@ struct transpose
} }
return {t, output_lens, output_strides}; return {t, output_lens, output_strides};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
...@@ -370,6 +370,27 @@ struct concat ...@@ -370,6 +370,27 @@ struct concat
new_lens[axis] = new_dim_axis; new_lens[axis] = new_dim_axis;
return {type, new_lens}; return {type, new_lens};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
for(std::size_t l = 0; l < args.size(); l++)
{
auto argl = args[l];
std::size_t nelements = argl.get_shape().elements();
visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape =
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]);
// cppcheck-suppress useStlAlgorithm
for(std::size_t i = 0; i < nelements; i++)
{
slice[i] = input[i];
}
});
}
return result;
}
int output_alias(const std::vector<shape>&) const { return 0; } int output_alias(const std::vector<shape>&) const { return 0; }
}; };
...@@ -437,7 +458,7 @@ struct slice ...@@ -437,7 +458,7 @@ struct slice
} }
return shape{t, new_lens, old_strides}; return shape{t, new_lens, old_strides};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
auto input = args[0]; auto input = args[0];
auto offset = compute_offset(input.get_shape()) * output_shape.type_size(); auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
...@@ -487,7 +508,7 @@ struct squeeze ...@@ -487,7 +508,7 @@ struct squeeze
} }
return shape{type, new_lens}; return shape{type, new_lens};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
...@@ -526,7 +547,7 @@ struct unsqueeze ...@@ -526,7 +547,7 @@ struct unsqueeze
} }
return shape{type, new_lens}; return shape{type, new_lens};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
...@@ -578,7 +599,7 @@ struct reshape ...@@ -578,7 +599,7 @@ struct reshape
MIGRAPHX_THROW("Wrong number of elements for reshape"); MIGRAPHX_THROW("Wrong number of elements for reshape");
return s; return s;
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
...@@ -624,7 +645,7 @@ struct identity ...@@ -624,7 +645,7 @@ struct identity
{ {
std::string name() const { return "identity"; } std::string name() const { return "identity"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); } shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
...@@ -742,7 +763,7 @@ struct flatten ...@@ -742,7 +763,7 @@ struct flatten
std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{}); std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {inputs.at(0).type(), {x, y}}; return {inputs.at(0).type(), {x, y}};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
...@@ -794,7 +815,7 @@ struct broadcast ...@@ -794,7 +815,7 @@ struct broadcast
return {t, broadcast_shape.lens(), std::move(bcast_strides)}; return {t, broadcast_shape.lens(), std::move(bcast_strides)};
} }
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
...@@ -836,7 +857,7 @@ struct multibroadcast ...@@ -836,7 +857,7 @@ struct multibroadcast
} }
return {t, output_lens, bcast_strides}; return {t, output_lens, bcast_strides};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
...@@ -858,7 +879,7 @@ struct scalar ...@@ -858,7 +879,7 @@ struct scalar
return {t, scalar_bcast.lens(), strides}; return {t, scalar_bcast.lens(), strides};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
...@@ -923,7 +944,7 @@ struct load ...@@ -923,7 +944,7 @@ struct load
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return s; return s;
} }
argument compute(context&, const shape&, const std::vector<argument>& args) const argument compute(const shape&, const std::vector<argument>& args) const
{ {
return {s, args[0].data() + offset}; return {s, args[0].data() + offset};
} }
...@@ -946,10 +967,7 @@ struct outline ...@@ -946,10 +967,7 @@ struct outline
check_shapes{inputs, *this}.has(0); check_shapes{inputs, *this}.has(0);
return s; return s;
} }
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
{
return {s, nullptr};
}
}; };
} // namespace op } // namespace op
......
...@@ -170,6 +170,27 @@ std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args) ...@@ -170,6 +170,27 @@ std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
return shapes; return shapes;
} }
argument instruction::eval() const
{
if(op.name() == "@literal")
{
return this->get_literal().get_argument();
}
if(is_context_free(op))
{
std::vector<argument> args;
for(auto&& arg : this->inputs())
{
argument a = arg->eval();
if(a.empty())
return {};
args.push_back(a);
}
return op.compute(result, args);
}
return {};
}
instruction_ref instruction::get_output_alias(instruction_ref ins) instruction_ref instruction::get_output_alias(instruction_ref ins)
{ {
auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs())); auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs()));
......
...@@ -299,24 +299,7 @@ struct cpu_concat ...@@ -299,24 +299,7 @@ struct cpu_concat
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; return op.compute(output_shape, std::move(args));
std::vector<std::size_t> coffsets = op.compute_offsets(output_shape, args);
for(std::size_t l = 0; l < args.size(); l++)
{
auto argl = args[l];
std::size_t nelements = argl.get_shape().elements();
visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape =
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]);
// cppcheck-suppress useStlAlgorithm
for(std::size_t i = 0; i < nelements; i++)
{
slice[i] = input[i];
}
});
}
return result;
} }
}; };
......
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
struct sum_cf_op
{
std::string name() const { return "sum_cf"; }
migraphx::argument compute(const migraphx::shape&, std::vector<migraphx::argument> args) const
{
migraphx::argument result;
if(args.size() != 2)
MIGRAPHX_THROW("Wrong args");
if(args[0].get_shape() != args[1].get_shape())
MIGRAPHX_THROW("Wrong args");
if(args[0].get_shape().lens().size() != 1)
MIGRAPHX_THROW("Wrong args");
if(args[0].get_shape().lens().front() != 1)
MIGRAPHX_THROW("Wrong args");
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) { result = migraphx::literal{x + y}.get_argument(); });
});
return result;
}
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
if(inputs.size() != 2)
MIGRAPHX_THROW("Wrong inputs");
return inputs.front();
}
};
struct non_computable_cf
{
std::string name() const { return "non_computable"; }
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
if(inputs.empty())
return {};
return inputs.front();
}
};
struct test_context
{
void finish() const {}
};
TEST_CASE(literal_test)
{
migraphx::program p;
auto lit = p.add_literal(1);
CHECK(lit->eval() == migraphx::literal{1});
}
TEST_CASE(param_test)
{
migraphx::program p;
auto lit = p.add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}});
CHECK(lit->eval().empty());
}
TEST_CASE(op_test1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_cf_op{}, one, two);
CHECK(sum->eval() == migraphx::literal{3});
}
TEST_CASE(op_test2)
{
migraphx::program p;
auto x = p.add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}});
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_cf_op{}, x, two);
CHECK(sum->eval().empty());
}
TEST_CASE(op_test3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_cf_op{}, sum1, two);
CHECK(sum2->eval().empty());
}
TEST_CASE(compute_op_c)
{
migraphx::operation op = sum_op{};
auto one = migraphx::literal{1}.get_argument();
auto two = migraphx::literal{2}.get_argument();
EXPECT(test::throws([&] {
op.compute(migraphx::shape{migraphx::shape::float_type, {1}}, {one, two});
}));
}
TEST_CASE(compute_nop_c)
{
migraphx::operation op = non_computable_cf{};
auto one = migraphx::literal{1}.get_argument();
auto two = migraphx::literal{2}.get_argument();
EXPECT(test::throws([&] {
op.compute(migraphx::shape{migraphx::shape::float_type, {1}}, {one, two});
}));
}
TEST_CASE(compute_nop_context)
{
migraphx::operation op = non_computable_cf{};
auto one = migraphx::literal{1}.get_argument();
auto two = migraphx::literal{2}.get_argument();
migraphx::context ctx = test_context{};
EXPECT(test::throws([&] {
op.compute(ctx, migraphx::shape{migraphx::shape::float_type, {1}}, {one, two});
}));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -53,6 +53,9 @@ struct operation ...@@ -53,6 +53,9 @@ struct operation
friend std::ostream& operator<<(std::ostream& os, const operation& op); friend std::ostream& operator<<(std::ostream& os, const operation& op);
}; };
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
#else #else
namespace operation_stream { namespace operation_stream {
...@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_equal } // namespace operation_equal
template <class T> template <class T>
auto compute_op(rank<1>, auto compute_op(rank<2>,
const T& x, const T& x,
context& ctx, context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -99,6 +102,14 @@ auto compute_op(rank<1>, ...@@ -99,6 +102,14 @@ auto compute_op(rank<1>,
return x.compute(auto_any_cast(ctx), output_shape, input); return x.compute(auto_any_cast(ctx), output_shape, input);
} }
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T> template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&) argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{ {
...@@ -110,7 +121,53 @@ template <class T> ...@@ -110,7 +121,53 @@ template <class T>
argument argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input) compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return compute_op(rank<1>{}, x, ctx, output_shape, input); return compute_op(rank<2>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, output_shape, input);
}
template <class T>
auto is_context_free_op(rank<1>,
const T& x,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input), std::true_type{});
template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
-> std::false_type;
template <class T>
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
rank<1>{}, x, std::declval<const shape&>(), std::declval<std::vector<argument>>()))
{
return {};
} }
template <class T> template <class T>
...@@ -136,6 +193,7 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -136,6 +193,7 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
interface( interface(
'operation', 'operation',
virtual('name', returns = 'std::string', const = True), virtual('name', returns = 'std::string', const = True),
virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'),
virtual('output_alias', virtual('output_alias',
returns = 'int', returns = 'int',
input = 'const std::vector<shape>&', input = 'const std::vector<shape>&',
...@@ -149,6 +207,12 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -149,6 +207,12 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
input = 'const std::vector<argument>&', input = 'const std::vector<argument>&',
const = True, const = True,
default = 'compute_op'), default = 'compute_op'),
virtual('compute',
returns = 'argument',
output = 'const shape&',
input = 'const std::vector<argument>&',
const = True,
default = 'compute_op'),
friend('operator<<', friend('operator<<',
returns = 'std::ostream &', returns = 'std::ostream &',
os = 'std::ostream &', os = 'std::ostream &',
...@@ -165,6 +229,14 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -165,6 +229,14 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return !(x == y); return !(x == y);
} }
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
template <class T>
bool is_context_free(const T& x)
{
return is_context_free_op(x);
}
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment