Commit 43ae3419 authored by Paul's avatar Paul
Browse files

Add context parameter

parent 7591b7ff
pfultz2/rocm-recipes pfultz2/rocm-recipes
danmar/cppcheck@d9f9bdda7344e80585f71141be7797055d7987f3 danmar/cppcheck@d9f9bdda7344e80585f71141be7797055d7987f3
# python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
-f requirements.txt -f requirements.txt
#ifndef RTG_GUARD_BUILTIN_HPP #ifndef RTG_GUARD_BUILTIN_HPP
#define RTG_GUARD_BUILTIN_HPP #define RTG_GUARD_BUILTIN_HPP
#include <rtg/operation.hpp> #include <rtg/context.hpp>
#include <rtg/errors.hpp> #include <rtg/errors.hpp>
namespace rtg { namespace rtg {
...@@ -12,7 +12,7 @@ struct literal ...@@ -12,7 +12,7 @@ struct literal
{ {
std::string name() const { return "@literal"; } std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("builtin"); }
}; };
struct outline struct outline
...@@ -20,7 +20,7 @@ struct outline ...@@ -20,7 +20,7 @@ struct outline
shape s; shape s;
std::string name() const { return "@outline"; } std::string name() const { return "@outline"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("builtin"); }
}; };
struct param struct param
...@@ -28,7 +28,7 @@ struct param ...@@ -28,7 +28,7 @@ struct param
std::string parameter; std::string parameter;
std::string name() const { return "@param"; } std::string name() const { return "@param"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const param& op) friend std::ostream& operator<<(std::ostream& os, const param& op)
{ {
os << op.name() << ":" << op.parameter; os << op.name() << ":" << op.parameter;
......
#ifndef RTG_GUARD_CONTEXT_HPP
#define RTG_GUARD_CONTEXT_HPP
namespace rtg {
/*
* Type-erased interface for:
*
* struct context
* {
* };
*
*/
struct context
{
// Constructors
context() = default;
template <typename PrivateDetailTypeErasedT>
context(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
context& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const context* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(context* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(context& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const context& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
} // namespace rtg
#endif
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <utility> #include <utility>
#include <rtg/shape.hpp> #include <rtg/shape.hpp>
#include <rtg/argument.hpp> #include <rtg/argument.hpp>
#include <rtg/context.hpp>
namespace rtg { namespace rtg {
...@@ -28,7 +29,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -28,7 +29,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* { * {
* std::string name() const; * std::string name() const;
* shape compute_shape(std::vector<shape> input) const; * shape compute_shape(std::vector<shape> input) const;
* argument compute(shape output,std::vector<argument> input) const; * argument compute(context& ctx,shape output,std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* }; * };
* *
...@@ -95,10 +96,11 @@ struct operation ...@@ -95,10 +96,11 @@ struct operation
return (*this).private_detail_te_get_handle().compute_shape(std::move(input)); return (*this).private_detail_te_get_handle().compute_shape(std::move(input));
} }
argument compute(shape output, std::vector<argument> input) const argument compute(context& ctx, shape output, std::vector<argument> input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(std::move(output), std::move(input)); return (*this).private_detail_te_get_handle().compute(
ctx, std::move(output), std::move(input));
} }
friend std::ostream& operator<<(std::ostream& os, const operation& op) friend std::ostream& operator<<(std::ostream& os, const operation& op)
...@@ -116,7 +118,7 @@ struct operation ...@@ -116,7 +118,7 @@ struct operation
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0; virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(shape output, std::vector<argument> input) const = 0; virtual argument compute(context& ctx, shape output, std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
}; };
...@@ -156,10 +158,10 @@ struct operation ...@@ -156,10 +158,10 @@ struct operation
return private_detail_te_value.compute_shape(std::move(input)); return private_detail_te_value.compute_shape(std::move(input));
} }
argument compute(shape output, std::vector<argument> input) const override argument compute(context& ctx, shape output, std::vector<argument> input) const override
{ {
return private_detail_te_value.compute(std::move(output), std::move(input)); return private_detail_te_value.compute(ctx, std::move(output), std::move(input));
} }
std::ostream& operator_shift_left(std::ostream& os) const override std::ostream& operator_shift_left(std::ostream& os) const override
......
...@@ -83,7 +83,7 @@ struct check_shapes ...@@ -83,7 +83,7 @@ struct check_shapes
struct not_computable struct not_computable
{ {
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct convolution struct convolution
...@@ -153,7 +153,7 @@ struct convolution ...@@ -153,7 +153,7 @@ struct convolution
} }
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const convolution& op) friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{ {
...@@ -200,7 +200,7 @@ struct pooling ...@@ -200,7 +200,7 @@ struct pooling
}}; }};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const pooling& op) friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{ {
...@@ -223,7 +223,7 @@ struct activation ...@@ -223,7 +223,7 @@ struct activation
return inputs.front(); return inputs.front();
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const activation& op) friend std::ostream& operator<<(std::ostream& os, const activation& op)
{ {
os << op.name() << ":" << op.mode; os << op.name() << ":" << op.mode;
...@@ -261,7 +261,7 @@ struct transpose ...@@ -261,7 +261,7 @@ struct transpose
} }
return {t, output_lens, output_strides}; return {t, output_lens, output_strides};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct contiguous struct contiguous
...@@ -278,7 +278,7 @@ struct contiguous ...@@ -278,7 +278,7 @@ struct contiguous
} }
return {t, lens}; return {t, lens};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct reshape struct reshape
...@@ -306,7 +306,7 @@ struct reshape ...@@ -306,7 +306,7 @@ struct reshape
return s; return s;
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const reshape& op) friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{ {
...@@ -332,7 +332,7 @@ struct gemm ...@@ -332,7 +332,7 @@ struct gemm
return {t, {a.lens()[0], b.lens()[1]}}; return {t, {a.lens()[0], b.lens()[1]}};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const gemm& op) friend std::ostream& operator<<(std::ostream& os, const gemm& op)
{ {
...@@ -349,7 +349,7 @@ struct unary ...@@ -349,7 +349,7 @@ struct unary
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.at(0); return inputs.at(0);
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct identity : unary struct identity : unary
...@@ -449,7 +449,7 @@ struct broadcast ...@@ -449,7 +449,7 @@ struct broadcast
return {t, result.lens(), std::move(bcast_strides)}; return {t, result.lens(), std::move(bcast_strides)};
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {output_shape, std::move(args.at(1).data)}; return {output_shape, std::move(args.at(1).data)};
} }
...@@ -463,7 +463,7 @@ struct binary ...@@ -463,7 +463,7 @@ struct binary
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0); return inputs.at(0);
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct add : binary struct add : binary
...@@ -495,7 +495,7 @@ struct outline ...@@ -495,7 +495,7 @@ struct outline
check_shapes{inputs}.has(0); check_shapes{inputs}.has(0);
return s; return s;
} }
argument compute(shape, std::vector<argument>) const { return {s, nullptr}; } argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; }
}; };
} // namespace rtg } // namespace rtg
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <rtg/context.hpp>
namespace rtg { namespace rtg {
...@@ -18,6 +19,7 @@ struct program; ...@@ -18,6 +19,7 @@ struct program;
* { * {
* std::string name() const; * std::string name() const;
* void apply(program & p) const; * void apply(program & p) const;
* context get_context() const;
* }; * };
* *
*/ */
...@@ -83,6 +85,12 @@ struct target ...@@ -83,6 +85,12 @@ struct target
return (*this).private_detail_te_get_handle().apply(p); return (*this).private_detail_te_get_handle().apply(p);
} }
context get_context() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_context();
}
private: private:
struct private_detail_te_handle_base_type struct private_detail_te_handle_base_type
{ {
...@@ -92,6 +100,7 @@ struct target ...@@ -92,6 +100,7 @@ struct target
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual void apply(program& p) const = 0; virtual void apply(program& p) const = 0;
virtual context get_context() const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -126,6 +135,8 @@ struct target ...@@ -126,6 +135,8 @@ struct target
void apply(program& p) const override { return private_detail_te_value.apply(p); } void apply(program& p) const override { return private_detail_te_value.apply(p); }
context get_context() const override { return private_detail_te_value.get_context(); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
......
...@@ -25,7 +25,7 @@ struct unknown ...@@ -25,7 +25,7 @@ struct unknown
else else
return input.front(); return input.front();
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const unknown& x) friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{ {
os << x.name(); os << x.name();
......
...@@ -10,6 +10,7 @@ struct program_impl ...@@ -10,6 +10,7 @@ struct program_impl
{ {
// A list is used to keep references to an instruction stable // A list is used to keep references to an instruction stable
std::list<instruction> instructions; std::list<instruction> instructions;
context ctx;
}; };
const operation& get_operation(instruction_ref ins) { return ins->op; } const operation& get_operation(instruction_ref ins) { return ins->op; }
...@@ -109,6 +110,7 @@ instruction_ref program::validate() const ...@@ -109,6 +110,7 @@ instruction_ref program::validate() const
void program::compile(const target& t) void program::compile(const target& t)
{ {
assert(this->validate() != impl->instructions.end()); assert(this->validate() != impl->instructions.end());
this->impl->ctx = t.get_context();
t.apply(*this); t.apply(*this);
if(this->validate() == impl->instructions.end()) if(this->validate() == impl->instructions.end())
RTG_THROW("Invalid program from compilation"); RTG_THROW("Invalid program from compilation");
...@@ -140,7 +142,7 @@ argument program::eval(std::unordered_map<std::string, argument> params) const ...@@ -140,7 +142,7 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
ins.arguments.end(), ins.arguments.end(),
values.begin(), values.begin(),
[&](instruction_ref i) { return results.at(std::addressof(*i)); }); [&](instruction_ref i) { return results.at(std::addressof(*i)); });
result = ins.op.compute(ins.result, values); result = ins.op.compute(this->impl->ctx, ins.result, values);
} }
results.emplace(std::addressof(ins), result); results.emplace(std::addressof(ins), result);
} }
......
...@@ -20,7 +20,7 @@ struct cpu_convolution ...@@ -20,7 +20,7 @@ struct cpu_convolution
std::string name() const { return "cpu::convolution"; } std::string name() const { return "cpu::convolution"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) { visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
...@@ -86,7 +86,7 @@ struct cpu_pooling ...@@ -86,7 +86,7 @@ struct cpu_pooling
std::string name() const { return "cpu::pooling_" + Op::name(); } std::string name() const { return "cpu::pooling_" + Op::name(); }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
...@@ -134,7 +134,7 @@ struct cpu_transpose ...@@ -134,7 +134,7 @@ struct cpu_transpose
std::string name() const { return "cpu::transpose"; } std::string name() const { return "cpu::transpose"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {output_shape, std::move(args.front().data)}; return {output_shape, std::move(args.front().data)};
} }
...@@ -145,7 +145,7 @@ struct cpu_contiguous ...@@ -145,7 +145,7 @@ struct cpu_contiguous
contiguous op; contiguous op;
std::string name() const { return "cpu::contiguous"; } std::string name() const { return "cpu::contiguous"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
...@@ -163,7 +163,7 @@ struct cpu_reshape ...@@ -163,7 +163,7 @@ struct cpu_reshape
std::string name() const { return "cpu::reshape"; } std::string name() const { return "cpu::reshape"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {output_shape, std::move(args.front().data)}; return {output_shape, std::move(args.front().data)};
} }
...@@ -175,7 +175,7 @@ struct cpu_gemm ...@@ -175,7 +175,7 @@ struct cpu_gemm
std::string name() const { return "cpu::gemm"; } std::string name() const { return "cpu::gemm"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto cmat, auto amat, auto bmat) { visit_all(result, args[0], args[1])([&](auto cmat, auto amat, auto bmat) {
...@@ -334,7 +334,7 @@ struct cpu_unary ...@@ -334,7 +334,7 @@ struct cpu_unary
Op op; Op op;
std::string name() const { return op.name(); } std::string name() const { return op.name(); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); } shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
result.visit([&](auto output) { result.visit([&](auto output) {
...@@ -350,7 +350,7 @@ struct softmax2d ...@@ -350,7 +350,7 @@ struct softmax2d
{ {
std::string name() const { return "cpu::softmax2d"; } std::string name() const { return "cpu::softmax2d"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); } shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
...@@ -426,7 +426,7 @@ struct cpu_binary ...@@ -426,7 +426,7 @@ struct cpu_binary
Op op; Op op;
std::string name() const { return op.name(); } std::string name() const { return op.name(); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); } shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) { visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
......
...@@ -10,6 +10,7 @@ struct cpu_target ...@@ -10,6 +10,7 @@ struct cpu_target
{ {
std::string name() const; std::string name() const;
void apply(program& p) const; void apply(program& p) const;
context get_context() const { return {}; }
}; };
} // namespace cpu } // namespace cpu
......
...@@ -20,7 +20,7 @@ struct hip_allocate ...@@ -20,7 +20,7 @@ struct hip_allocate
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.front(); return inputs.front();
} }
argument compute(shape output_shape, std::vector<argument>) const argument compute(context&, shape output_shape, std::vector<argument>) const
{ {
return allocate_gpu(output_shape); return allocate_gpu(output_shape);
} }
......
...@@ -10,6 +10,7 @@ struct miopen_target ...@@ -10,6 +10,7 @@ struct miopen_target
{ {
std::string name() const; std::string name() const;
void apply(program& p) const; void apply(program& p) const;
context get_context() const { return {}; }
}; };
} // namespace miopen } // namespace miopen
......
...@@ -21,7 +21,7 @@ struct miopen_convolution ...@@ -21,7 +21,7 @@ struct miopen_convolution
check_shapes{inputs}.has(4); check_shapes{inputs}.has(4);
return op.compute_shape({inputs.at(1), inputs.at(2)}); return op.compute_shape({inputs.at(1), inputs.at(2)});
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
auto x_desc = make_tensor(args[1].get_shape()); auto x_desc = make_tensor(args[1].get_shape());
auto w_desc = make_tensor(args[2].get_shape()); auto w_desc = make_tensor(args[2].get_shape());
...@@ -72,7 +72,7 @@ struct miopen_pooling ...@@ -72,7 +72,7 @@ struct miopen_pooling
check_shapes{inputs}.has(3); check_shapes{inputs}.has(3);
return op.compute_shape({inputs.at(1)}); return op.compute_shape({inputs.at(1)});
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
auto x_desc = make_tensor(args[1].get_shape()); auto x_desc = make_tensor(args[1].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
...@@ -104,7 +104,7 @@ struct miopen_add ...@@ -104,7 +104,7 @@ struct miopen_add
return inputs.at(1); return inputs.at(1);
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
if(args[2].get_shape().broadcasted()) if(args[2].get_shape().broadcasted())
{ {
...@@ -150,7 +150,7 @@ struct miopen_gemm ...@@ -150,7 +150,7 @@ struct miopen_gemm
check_shapes{inputs}.has(4); check_shapes{inputs}.has(4);
return op.compute_shape({inputs.at(1), inputs.at(2)}); return op.compute_shape({inputs.at(1), inputs.at(2)});
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
...@@ -175,7 +175,7 @@ struct miopen_relu ...@@ -175,7 +175,7 @@ struct miopen_relu
return inputs.at(1); return inputs.at(1);
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[1].get_shape()); auto x_desc = make_tensor(args[1].get_shape());
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
struct sum_op struct sum_op
{ {
std::string name() const { return "sum"; } std::string name() const { return "sum"; }
rtg::argument compute(rtg::shape, std::vector<rtg::argument> args) const rtg::argument compute(rtg::context&, rtg::shape, std::vector<rtg::argument> args) const
{ {
rtg::argument result; rtg::argument result;
if(args.size() != 2) if(args.size() != 2)
...@@ -37,7 +37,7 @@ struct sum_op ...@@ -37,7 +37,7 @@ struct sum_op
struct minus_op struct minus_op
{ {
std::string name() const { return "minus"; } std::string name() const { return "minus"; }
rtg::argument compute(rtg::shape, std::vector<rtg::argument> args) const rtg::argument compute(rtg::context&, rtg::shape, std::vector<rtg::argument> args) const
{ {
rtg::argument result; rtg::argument result;
if(args.size() != 2) if(args.size() != 2)
...@@ -67,6 +67,7 @@ struct id_target ...@@ -67,6 +67,7 @@ struct id_target
{ {
std::string name() const { return "id"; } std::string name() const { return "id"; }
void apply(rtg::program&) const {} void apply(rtg::program&) const {}
rtg::context get_context() const { return {}; }
}; };
void literal_test1() void literal_test1()
......
...@@ -9,7 +9,7 @@ struct simple_operation ...@@ -9,7 +9,7 @@ struct simple_operation
int data = 1; int data = 1;
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); } rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(rtg::shape, std::vector<rtg::argument>) const rtg::argument compute(rtg::context&, rtg::shape, std::vector<rtg::argument>) const
{ {
RTG_THROW("not computable"); RTG_THROW("not computable");
} }
...@@ -24,7 +24,7 @@ struct simple_operation_no_print ...@@ -24,7 +24,7 @@ struct simple_operation_no_print
{ {
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); } rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(rtg::shape, std::vector<rtg::argument>) const rtg::argument compute(rtg::context&, rtg::shape, std::vector<rtg::argument>) const
{ {
RTG_THROW("not computable"); RTG_THROW("not computable");
} }
......
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "python $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $DIR/../src/include/rtg/{}" ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "python3.6 $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $DIR/../src/include/rtg/{}"
#ifndef RTG_GUARD_CONTEXT_HPP
#define RTG_GUARD_CONTEXT_HPP
namespace rtg {
<%
interface('context')
%>
} // namespace rtg
#endif
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <utility> #include <utility>
#include <rtg/shape.hpp> #include <rtg/shape.hpp>
#include <rtg/argument.hpp> #include <rtg/argument.hpp>
#include <rtg/context.hpp>
namespace rtg { namespace rtg {
...@@ -25,7 +26,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -25,7 +26,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
interface('operation', interface('operation',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True), virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True),
virtual('compute', returns='argument', output='shape', input='std::vector<argument>', const=True), virtual('compute', returns='argument', ctx='context&', output='shape', input='std::vector<argument>', const=True),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='rtg::operation_stream::operator<<') friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='rtg::operation_stream::operator<<')
) )
%> %>
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <rtg/context.hpp>
namespace rtg { namespace rtg {
...@@ -14,7 +15,8 @@ struct program; ...@@ -14,7 +15,8 @@ struct program;
<% <%
interface('target', interface('target',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('apply', returns='void', p='program &', const=True) virtual('apply', returns='void', p='program &', const=True),
virtual('get_context', returns='context', const=True)
) )
%> %>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment