Unverified Commit 5f1ea74f authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #9 from ROCmSoftwarePlatform/context

Add context object for execution
parents 7591b7ff a7dcd9fb
pfultz2/rocm-recipes pfultz2/rocm-recipes
danmar/cppcheck@d9f9bdda7344e80585f71141be7797055d7987f3 danmar/cppcheck@d9f9bdda7344e80585f71141be7797055d7987f3
# python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
-f requirements.txt -f requirements.txt
#ifndef RTG_GUARD_BUILTIN_HPP #ifndef RTG_GUARD_BUILTIN_HPP
#define RTG_GUARD_BUILTIN_HPP #define RTG_GUARD_BUILTIN_HPP
#include <rtg/operation.hpp> #include <rtg/context.hpp>
#include <rtg/errors.hpp> #include <rtg/errors.hpp>
namespace rtg { namespace rtg {
...@@ -12,7 +12,7 @@ struct literal ...@@ -12,7 +12,7 @@ struct literal
{ {
std::string name() const { return "@literal"; } std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("builtin"); }
}; };
struct outline struct outline
...@@ -20,7 +20,7 @@ struct outline ...@@ -20,7 +20,7 @@ struct outline
shape s; shape s;
std::string name() const { return "@outline"; } std::string name() const { return "@outline"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("builtin"); }
}; };
struct param struct param
...@@ -28,7 +28,7 @@ struct param ...@@ -28,7 +28,7 @@ struct param
std::string parameter; std::string parameter;
std::string name() const { return "@param"; } std::string name() const { return "@param"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const param& op) friend std::ostream& operator<<(std::ostream& os, const param& op)
{ {
os << op.name() << ":" << op.parameter; os << op.name() << ":" << op.parameter;
......
#ifndef RTG_GUARD_CONTEXT_HPP
#define RTG_GUARD_CONTEXT_HPP
namespace rtg {
/*
* Type-erased interface for:
*
* struct context
* {
* };
*
*/
struct context
{
// Constructors
context() = default;
template <typename PrivateDetailTypeErasedT>
context(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
context& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const context* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(context* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(context& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const context& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
} // namespace rtg
#endif
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <utility> #include <utility>
#include <rtg/shape.hpp> #include <rtg/shape.hpp>
#include <rtg/argument.hpp> #include <rtg/argument.hpp>
#include <rtg/context.hpp>
namespace rtg { namespace rtg {
...@@ -28,7 +29,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -28,7 +29,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* { * {
* std::string name() const; * std::string name() const;
* shape compute_shape(std::vector<shape> input) const; * shape compute_shape(std::vector<shape> input) const;
* argument compute(shape output,std::vector<argument> input) const; * argument compute(context& ctx,shape output,std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* }; * };
* *
...@@ -83,6 +84,14 @@ struct operation ...@@ -83,6 +84,14 @@ struct operation
: nullptr; : nullptr;
} }
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const std::string name() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -95,10 +104,11 @@ struct operation ...@@ -95,10 +104,11 @@ struct operation
return (*this).private_detail_te_get_handle().compute_shape(std::move(input)); return (*this).private_detail_te_get_handle().compute_shape(std::move(input));
} }
argument compute(shape output, std::vector<argument> input) const argument compute(context& ctx, shape output, std::vector<argument> input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(std::move(output), std::move(input)); return (*this).private_detail_te_get_handle().compute(
ctx, std::move(output), std::move(input));
} }
friend std::ostream& operator<<(std::ostream& os, const operation& op) friend std::ostream& operator<<(std::ostream& os, const operation& op)
...@@ -114,10 +124,10 @@ struct operation ...@@ -114,10 +124,10 @@ struct operation
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0; virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(shape output, std::vector<argument> input) const = 0; virtual argument compute(context& ctx, shape output, std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -156,10 +166,10 @@ struct operation ...@@ -156,10 +166,10 @@ struct operation
return private_detail_te_value.compute_shape(std::move(input)); return private_detail_te_value.compute_shape(std::move(input));
} }
argument compute(shape output, std::vector<argument> input) const override argument compute(context& ctx, shape output, std::vector<argument> input) const override
{ {
return private_detail_te_value.compute(std::move(output), std::move(input)); return private_detail_te_value.compute(ctx, std::move(output), std::move(input));
} }
std::ostream& operator_shift_left(std::ostream& os) const override std::ostream& operator_shift_left(std::ostream& os) const override
...@@ -181,13 +191,20 @@ struct operation ...@@ -181,13 +191,20 @@ struct operation
} }
}; };
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{ {
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
......
...@@ -12,15 +12,29 @@ namespace rtg { ...@@ -12,15 +12,29 @@ namespace rtg {
struct check_shapes struct check_shapes
{ {
const std::vector<shape>* shapes; const std::vector<shape>* shapes;
const std::string name;
check_shapes(const std::vector<shape>& s) : shapes(&s) {} check_shapes(const std::vector<shape>& s) : shapes(&s) {}
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name())
{
}
std::string prefix() const
{
if(name.empty())
return "";
else
return name + ": ";
}
const check_shapes& has(std::size_t n) const const check_shapes& has(std::size_t n) const
{ {
assert(shapes != nullptr); assert(shapes != nullptr);
if(shapes->size() != n) if(shapes->size() != n)
RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " + RTG_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
std::to_string(shapes->size())); " but given " + std::to_string(shapes->size()));
return *this; return *this;
} }
...@@ -30,7 +44,7 @@ struct check_shapes ...@@ -30,7 +44,7 @@ struct check_shapes
if(!shapes->empty()) if(!shapes->empty())
{ {
if(shapes->front().lens().size() != n) if(shapes->front().lens().size() != n)
RTG_THROW("Only " + std::to_string(n) + "d supported"); RTG_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
} }
return *this; return *this;
} }
...@@ -38,28 +52,28 @@ struct check_shapes ...@@ -38,28 +52,28 @@ struct check_shapes
const check_shapes& same_shape() const const check_shapes& same_shape() const
{ {
if(!this->same([](const shape& s) { return s; })) if(!this->same([](const shape& s) { return s; }))
RTG_THROW("Shapes do not match"); RTG_THROW(prefix() + "Shapes do not match");
return *this; return *this;
} }
const check_shapes& same_type() const const check_shapes& same_type() const
{ {
if(!this->same([](const shape& s) { return s.type(); })) if(!this->same([](const shape& s) { return s.type(); }))
RTG_THROW("Types do not match"); RTG_THROW(prefix() + "Types do not match");
return *this; return *this;
} }
const check_shapes& same_dims() const const check_shapes& same_dims() const
{ {
if(!this->same([](const shape& s) { return s.lens(); })) if(!this->same([](const shape& s) { return s.lens(); }))
RTG_THROW("Dimensions do not match"); RTG_THROW(prefix() + "Dimensions do not match");
return *this; return *this;
} }
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(!this->same([](const shape& s) { return s.lens().size(); })) if(!this->same([](const shape& s) { return s.lens().size(); }))
RTG_THROW("Dimensions do not match"); RTG_THROW(prefix() + "Dimensions do not match");
return *this; return *this;
} }
...@@ -83,7 +97,7 @@ struct check_shapes ...@@ -83,7 +97,7 @@ struct check_shapes
struct not_computable struct not_computable
{ {
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct convolution struct convolution
...@@ -101,7 +115,7 @@ struct convolution ...@@ -101,7 +115,7 @@ struct convolution
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_ndims().only_dims(4); check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
...@@ -153,7 +167,7 @@ struct convolution ...@@ -153,7 +167,7 @@ struct convolution
} }
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const convolution& op) friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{ {
...@@ -175,7 +189,7 @@ struct pooling ...@@ -175,7 +189,7 @@ struct pooling
std::string name() const { return "pooling"; } std::string name() const { return "pooling"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1).only_dims(4); check_shapes{inputs, *this}.has(1).only_dims(4);
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
auto t = input.type(); auto t = input.type();
...@@ -200,7 +214,7 @@ struct pooling ...@@ -200,7 +214,7 @@ struct pooling
}}; }};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const pooling& op) friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{ {
...@@ -219,11 +233,11 @@ struct activation ...@@ -219,11 +233,11 @@ struct activation
std::string name() const { return "activation"; } std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
return inputs.front(); return inputs.front();
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const activation& op) friend std::ostream& operator<<(std::ostream& os, const activation& op)
{ {
os << op.name() << ":" << op.mode; os << op.name() << ":" << op.mode;
...@@ -237,7 +251,7 @@ struct transpose ...@@ -237,7 +251,7 @@ struct transpose
std::string name() const { return "transpose"; } std::string name() const { return "transpose"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0); auto input = inputs.at(0);
auto input_lens = input.lens(); auto input_lens = input.lens();
auto input_strides = input.strides(); auto input_strides = input.strides();
...@@ -261,7 +275,7 @@ struct transpose ...@@ -261,7 +275,7 @@ struct transpose
} }
return {t, output_lens, output_strides}; return {t, output_lens, output_strides};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct contiguous struct contiguous
...@@ -269,7 +283,7 @@ struct contiguous ...@@ -269,7 +283,7 @@ struct contiguous
std::string name() const { return "contiguous"; } std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
auto lens = inputs.at(0).lens(); auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
if(lens.size() < 2) if(lens.size() < 2)
...@@ -278,7 +292,7 @@ struct contiguous ...@@ -278,7 +292,7 @@ struct contiguous
} }
return {t, lens}; return {t, lens};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct reshape struct reshape
...@@ -287,7 +301,7 @@ struct reshape ...@@ -287,7 +301,7 @@ struct reshape
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens(); auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end()); std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++) for(std::size_t i = 0; i < dims.size(); i++)
...@@ -306,7 +320,7 @@ struct reshape ...@@ -306,7 +320,7 @@ struct reshape
return s; return s;
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const reshape& op) friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{ {
...@@ -322,7 +336,7 @@ struct gemm ...@@ -322,7 +336,7 @@ struct gemm
std::string name() const { return "gemm"; } std::string name() const { return "gemm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type(); check_shapes{inputs, *this}.has(2).same_type();
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
...@@ -332,7 +346,7 @@ struct gemm ...@@ -332,7 +346,7 @@ struct gemm
return {t, {a.lens()[0], b.lens()[1]}}; return {t, {a.lens()[0], b.lens()[1]}};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const gemm& op) friend std::ostream& operator<<(std::ostream& os, const gemm& op)
{ {
...@@ -349,7 +363,7 @@ struct unary ...@@ -349,7 +363,7 @@ struct unary
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.at(0); return inputs.at(0);
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct identity : unary struct identity : unary
...@@ -449,7 +463,7 @@ struct broadcast ...@@ -449,7 +463,7 @@ struct broadcast
return {t, result.lens(), std::move(bcast_strides)}; return {t, result.lens(), std::move(bcast_strides)};
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {output_shape, std::move(args.at(1).data)}; return {output_shape, std::move(args.at(1).data)};
} }
...@@ -463,7 +477,7 @@ struct binary ...@@ -463,7 +477,7 @@ struct binary
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0); return inputs.at(0);
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct add : binary struct add : binary
...@@ -492,10 +506,24 @@ struct outline ...@@ -492,10 +506,24 @@ struct outline
std::string name() const { return "outline"; } std::string name() const { return "outline"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(0); check_shapes{inputs, *this}.has(0);
return s; return s;
} }
argument compute(shape, std::vector<argument>) const { return {s, nullptr}; } argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; }
};
template <class T>
struct check_context
{
std::string name() const { return "check_context"; }
shape compute_shape(std::vector<shape>) const { return {}; }
argument compute(context& ctx, shape, std::vector<argument>) const
{
T* x = any_cast<T>(&ctx);
if(x == nullptr)
RTG_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
return {};
}
}; };
} // namespace rtg } // namespace rtg
......
...@@ -89,13 +89,27 @@ struct raw_data : raw_data_base ...@@ -89,13 +89,27 @@ struct raw_data : raw_data_base
assert(self->single()); assert(self->single());
return self->template at<T>(); return self->template at<T>();
} }
template <class T>
using is_data_ptr =
bool_c<(std::is_void<T>{} or std::is_same<char, std::remove_cv_t<T>>{} or
std::is_same<unsigned char, std::remove_cv_t<T>>{})>;
template <class T>
using get_data_type = std::conditional_t<is_data_ptr<T>{}, float, T>;
template <class T>
bool matches() const
{
return is_data_ptr<T>{} ||
self->get_shape().type() == rtg::shape::get_type<get_data_type<T>>{};
}
template <class T> template <class T>
operator T*() operator T*()
{ {
using type = std::remove_cv_t<T>; using type = std::remove_cv_t<T>;
assert((std::is_void<T>{} or std::is_same<char, type>{} or assert(matches<T>());
std::is_same<unsigned char, type>{} or
self->get_shape().type() == rtg::shape::get_type<T>{}));
return reinterpret_cast<type*>(self->data()); return reinterpret_cast<type*>(self->data());
} }
}; };
......
...@@ -10,10 +10,15 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT ...@@ -10,10 +10,15 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
{ {
}; };
template <bool B>
using bool_c = std::integral_constant<bool, B>;
#ifdef CPPCHECK #ifdef CPPCHECK
#define RTG_REQUIRES(...) class = void #define RTG_REQUIRES(...) class = void
#else #else
#define RTG_REQUIRES(...) class = typename std::enable_if<and_<__VA_ARGS__, true>{}>::type #define RTG_REQUIRES(...) \
bool PrivateRequires##__LINE__ = true, \
class = typename std::enable_if<and_<__VA_ARGS__, PrivateRequires##__LINE__>{}>::type
#endif #endif
} // namespace rtg } // namespace rtg
......
...@@ -31,15 +31,12 @@ struct shape ...@@ -31,15 +31,12 @@ struct shape
#define RTG_SHAPE_ENUM_TYPES(x, t) x, #define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum type_t enum type_t
{ {
any_type,
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES) RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES)
}; };
#undef RTG_SHAPE_ENUM_TYPES #undef RTG_SHAPE_ENUM_TYPES
template <class T, class = void> template <class T, class = void>
struct get_type : std::integral_constant<type_t, any_type> struct get_type;
{
};
#define RTG_SHAPE_GET_TYPE(x, t) \ #define RTG_SHAPE_GET_TYPE(x, t) \
template <class T> \ template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \ struct get_type<t, T> : std::integral_constant<type_t, x> \
...@@ -125,7 +122,6 @@ struct shape ...@@ -125,7 +122,6 @@ struct shape
{ {
switch(this->m_type) switch(this->m_type)
{ {
case any_type: RTG_THROW("Cannot visit the any_type");
#define RTG_SHAPE_VISITOR_CASE(x, t) \ #define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return; case x: v(as<t>()); return;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE) RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <rtg/context.hpp>
namespace rtg { namespace rtg {
...@@ -18,6 +19,7 @@ struct program; ...@@ -18,6 +19,7 @@ struct program;
* { * {
* std::string name() const; * std::string name() const;
* void apply(program & p) const; * void apply(program & p) const;
* context get_context() const;
* }; * };
* *
*/ */
...@@ -71,6 +73,14 @@ struct target ...@@ -71,6 +73,14 @@ struct target
: nullptr; : nullptr;
} }
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const std::string name() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -83,6 +93,12 @@ struct target ...@@ -83,6 +93,12 @@ struct target
return (*this).private_detail_te_get_handle().apply(p); return (*this).private_detail_te_get_handle().apply(p);
} }
context get_context() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_context();
}
private: private:
struct private_detail_te_handle_base_type struct private_detail_te_handle_base_type
{ {
...@@ -92,6 +108,7 @@ struct target ...@@ -92,6 +108,7 @@ struct target
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual void apply(program& p) const = 0; virtual void apply(program& p) const = 0;
virtual context get_context() const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -126,6 +143,8 @@ struct target ...@@ -126,6 +143,8 @@ struct target
void apply(program& p) const override { return private_detail_te_value.apply(p); } void apply(program& p) const override { return private_detail_te_value.apply(p); }
context get_context() const override { return private_detail_te_value.get_context(); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
...@@ -139,13 +158,20 @@ struct target ...@@ -139,13 +158,20 @@ struct target
} }
}; };
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{ {
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
......
...@@ -25,7 +25,7 @@ struct unknown ...@@ -25,7 +25,7 @@ struct unknown
else else
return input.front(); return input.front();
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const unknown& x) friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{ {
os << x.name(); os << x.name();
......
...@@ -29,8 +29,7 @@ rtg::argument run_gpu(std::string file) ...@@ -29,8 +29,7 @@ rtg::argument run_gpu(std::string file)
auto output = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output"))); auto output = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output")));
auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate); auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate);
auto out = p.eval( auto out = p.eval({{"Input3", input3}, {"output", output}});
{{"Input3", input3}, {"handle", {rtg::shape::any_type, handle.get()}}, {"output", output}});
std::cout << p << std::endl; std::cout << p << std::endl;
return rtg::miopen::from_gpu(out); return rtg::miopen::from_gpu(out);
} }
......
...@@ -10,6 +10,7 @@ struct program_impl ...@@ -10,6 +10,7 @@ struct program_impl
{ {
// A list is used to keep references to an instruction stable // A list is used to keep references to an instruction stable
std::list<instruction> instructions; std::list<instruction> instructions;
context ctx;
}; };
const operation& get_operation(instruction_ref ins) { return ins->op; } const operation& get_operation(instruction_ref ins) { return ins->op; }
...@@ -109,6 +110,7 @@ instruction_ref program::validate() const ...@@ -109,6 +110,7 @@ instruction_ref program::validate() const
void program::compile(const target& t) void program::compile(const target& t)
{ {
assert(this->validate() != impl->instructions.end()); assert(this->validate() != impl->instructions.end());
this->impl->ctx = t.get_context();
t.apply(*this); t.apply(*this);
if(this->validate() == impl->instructions.end()) if(this->validate() == impl->instructions.end())
RTG_THROW("Invalid program from compilation"); RTG_THROW("Invalid program from compilation");
...@@ -140,7 +142,7 @@ argument program::eval(std::unordered_map<std::string, argument> params) const ...@@ -140,7 +142,7 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
ins.arguments.end(), ins.arguments.end(),
values.begin(), values.begin(),
[&](instruction_ref i) { return results.at(std::addressof(*i)); }); [&](instruction_ref i) { return results.at(std::addressof(*i)); });
result = ins.op.compute(ins.result, values); result = ins.op.compute(this->impl->ctx, ins.result, values);
} }
results.emplace(std::addressof(ins), result); results.emplace(std::addressof(ins), result);
} }
......
...@@ -109,7 +109,6 @@ std::string shape::type_string() const ...@@ -109,7 +109,6 @@ std::string shape::type_string() const
{ {
switch(this->m_type) switch(this->m_type)
{ {
case any_type: return "any";
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \ #define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
case x: return #x; case x: return #x;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_TYPE_STRING_CASE) RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_TYPE_STRING_CASE)
......
...@@ -20,7 +20,7 @@ struct cpu_convolution ...@@ -20,7 +20,7 @@ struct cpu_convolution
std::string name() const { return "cpu::convolution"; } std::string name() const { return "cpu::convolution"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) { visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
...@@ -86,7 +86,7 @@ struct cpu_pooling ...@@ -86,7 +86,7 @@ struct cpu_pooling
std::string name() const { return "cpu::pooling_" + Op::name(); } std::string name() const { return "cpu::pooling_" + Op::name(); }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
...@@ -134,7 +134,7 @@ struct cpu_transpose ...@@ -134,7 +134,7 @@ struct cpu_transpose
std::string name() const { return "cpu::transpose"; } std::string name() const { return "cpu::transpose"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {output_shape, std::move(args.front().data)}; return {output_shape, std::move(args.front().data)};
} }
...@@ -145,7 +145,7 @@ struct cpu_contiguous ...@@ -145,7 +145,7 @@ struct cpu_contiguous
contiguous op; contiguous op;
std::string name() const { return "cpu::contiguous"; } std::string name() const { return "cpu::contiguous"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
...@@ -163,7 +163,7 @@ struct cpu_reshape ...@@ -163,7 +163,7 @@ struct cpu_reshape
std::string name() const { return "cpu::reshape"; } std::string name() const { return "cpu::reshape"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {output_shape, std::move(args.front().data)}; return {output_shape, std::move(args.front().data)};
} }
...@@ -175,7 +175,7 @@ struct cpu_gemm ...@@ -175,7 +175,7 @@ struct cpu_gemm
std::string name() const { return "cpu::gemm"; } std::string name() const { return "cpu::gemm"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto cmat, auto amat, auto bmat) { visit_all(result, args[0], args[1])([&](auto cmat, auto amat, auto bmat) {
...@@ -334,7 +334,7 @@ struct cpu_unary ...@@ -334,7 +334,7 @@ struct cpu_unary
Op op; Op op;
std::string name() const { return op.name(); } std::string name() const { return op.name(); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); } shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
result.visit([&](auto output) { result.visit([&](auto output) {
...@@ -350,7 +350,7 @@ struct softmax2d ...@@ -350,7 +350,7 @@ struct softmax2d
{ {
std::string name() const { return "cpu::softmax2d"; } std::string name() const { return "cpu::softmax2d"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); } shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
...@@ -426,7 +426,7 @@ struct cpu_binary ...@@ -426,7 +426,7 @@ struct cpu_binary
Op op; Op op;
std::string name() const { return op.name(); } std::string name() const { return op.name(); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); } shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) { visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
......
...@@ -10,6 +10,7 @@ struct cpu_target ...@@ -10,6 +10,7 @@ struct cpu_target
{ {
std::string name() const; std::string name() const;
void apply(program& p) const; void apply(program& p) const;
context get_context() const { return {}; }
}; };
} // namespace cpu } // namespace cpu
......
...@@ -20,7 +20,7 @@ struct hip_allocate ...@@ -20,7 +20,7 @@ struct hip_allocate
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.front(); return inputs.front();
} }
argument compute(shape output_shape, std::vector<argument>) const argument compute(context&, shape output_shape, std::vector<argument>) const
{ {
return allocate_gpu(output_shape); return allocate_gpu(output_shape);
} }
......
...@@ -10,6 +10,7 @@ struct miopen_target ...@@ -10,6 +10,7 @@ struct miopen_target
{ {
std::string name() const; std::string name() const;
void apply(program& p) const; void apply(program& p) const;
context get_context() const;
}; };
} // namespace miopen } // namespace miopen
......
...@@ -10,6 +10,11 @@ ...@@ -10,6 +10,11 @@
namespace rtg { namespace rtg {
namespace miopen { namespace miopen {
struct miopen_context
{
shared<miopen_handle> handle;
};
struct miopen_convolution struct miopen_convolution
{ {
convolution op; convolution op;
...@@ -18,46 +23,47 @@ struct miopen_convolution ...@@ -18,46 +23,47 @@ struct miopen_convolution
std::string name() const { return "miopen::convolution"; } std::string name() const { return "miopen::convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(4); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(1), inputs.at(2)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto x_desc = make_tensor(args[1].get_shape()); auto& ctx = any_cast<miopen_context>(gctx);
auto w_desc = make_tensor(args[2].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto w_desc = make_tensor(args[1].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
int algo_count; int algo_count;
miopenConvAlgoPerf_t perf; miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(args[0].implicit(), miopenFindConvolutionForwardAlgorithm(ctx.handle.get(),
x_desc.get(), x_desc.get(),
args[1].implicit(), args[0].implicit(),
w_desc.get(), w_desc.get(),
args[2].implicit(), args[1].implicit(),
cd.get(), cd.get(),
y_desc.get(), y_desc.get(),
args[3].implicit(), args[2].implicit(),
1, 1,
&algo_count, &algo_count,
&perf, &perf,
nullptr, nullptr,
0, 0,
false); false);
miopenConvolutionForward(args[0].implicit(), miopenConvolutionForward(ctx.handle.get(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
args[1].implicit(), args[0].implicit(),
w_desc.get(), w_desc.get(),
args[2].implicit(), args[1].implicit(),
cd.get(), cd.get(),
perf.fwd_algo, perf.fwd_algo,
&beta, &beta,
y_desc.get(), y_desc.get(),
args[3].implicit(), args[2].implicit(),
nullptr, nullptr,
0); 0);
return args[3]; return args[2];
} }
}; };
...@@ -69,29 +75,30 @@ struct miopen_pooling ...@@ -69,29 +75,30 @@ struct miopen_pooling
std::string name() const { return "miopen::pooling"; } std::string name() const { return "miopen::pooling"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(3); check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(1)}); return op.compute_shape({inputs.at(1)});
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto x_desc = make_tensor(args[1].get_shape()); auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
miopenPoolingForward(args[0].implicit(), miopenPoolingForward(ctx.handle.get(),
pd.get(), pd.get(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
args[1].implicit(), args[0].implicit(),
&beta, &beta,
y_desc.get(), y_desc.get(),
args[2].implicit(), args[1].implicit(),
false, false,
nullptr, nullptr,
0); 0);
return args[2]; return args[1];
} }
}; };
...@@ -100,17 +107,17 @@ struct miopen_add ...@@ -100,17 +107,17 @@ struct miopen_add
std::string name() const { return "miopen::add"; } std::string name() const { return "miopen::add"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(4); check_shapes{inputs, *this}.has(3);
return inputs.at(1); return inputs.at(0);
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{ {
if(args[2].get_shape().broadcasted()) if(args[1].get_shape().broadcasted())
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, from_gpu(args[1]), from_gpu(args[2]))( visit_all(result, from_gpu(args[0]), from_gpu(args[1]))(
[&](auto output, auto input1, auto input2) { [&](auto output, auto input1, auto input2) {
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = output(idx.begin(), idx.end()) =
...@@ -121,22 +128,23 @@ struct miopen_add ...@@ -121,22 +128,23 @@ struct miopen_add
} }
else else
{ {
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto a_desc = make_tensor(args[1].get_shape()); auto a_desc = make_tensor(args[0].get_shape());
auto b_desc = make_tensor(args[2].get_shape()); auto b_desc = make_tensor(args[1].get_shape());
auto c_desc = make_tensor(output_shape); auto c_desc = make_tensor(output_shape);
miopenOpTensor(args[0].implicit(), miopenOpTensor(ctx.handle.get(),
miopenTensorOpAdd, miopenTensorOpAdd,
&alpha, &alpha,
a_desc.get(), a_desc.get(),
args[1].implicit(), args[0].implicit(),
&alpha, &alpha,
b_desc.get(), b_desc.get(),
args[2].implicit(), args[1].implicit(),
&beta, &beta,
c_desc.get(), c_desc.get(),
args[3].implicit()); args[2].implicit());
return args[3]; return args[2];
} }
} }
}; };
...@@ -147,14 +155,14 @@ struct miopen_gemm ...@@ -147,14 +155,14 @@ struct miopen_gemm
std::string name() const { return "miopen::convolution"; } std::string name() const { return "miopen::convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(4); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(1), inputs.at(2)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, from_gpu(args[1]), from_gpu(args[2]))( visit_all(result, from_gpu(args[0]), from_gpu(args[1]))(
[&](auto output, auto input1, auto input2) { [&](auto output, auto input1, auto input2) {
dfor(input1.get_shape().lens()[0], dfor(input1.get_shape().lens()[0],
input2.get_shape().lens()[1], input2.get_shape().lens()[1],
...@@ -171,36 +179,36 @@ struct miopen_relu ...@@ -171,36 +179,36 @@ struct miopen_relu
std::string name() const { return "miopen::relu"; } std::string name() const { return "miopen::relu"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(3); check_shapes{inputs, *this}.has(2);
return inputs.at(1); return inputs.at(1);
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[1].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
miopenActivationForward(args[0].implicit(), miopenActivationForward(ctx.handle.get(),
ad.get(), ad.get(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
args[1].implicit(), args[0].implicit(),
&beta, &beta,
y_desc.get(), y_desc.get(),
args[2].implicit()); args[1].implicit());
return args[2]; return args[1];
} }
}; };
struct miopen_apply struct miopen_apply
{ {
program* prog = nullptr; program* prog = nullptr;
instruction_ref handle{};
void apply() void apply()
{ {
handle = prog->add_parameter("handle", shape{shape::any_type}); prog->insert_instruction(prog->begin(), check_context<miopen_context>{});
for(auto it = prog->begin(); it != prog->end(); it++) for(auto it = prog->begin(); it != prog->end(); it++)
{ {
if(it->op.name() == "convolution") if(it->op.name() == "convolution")
...@@ -248,7 +256,6 @@ struct miopen_apply ...@@ -248,7 +256,6 @@ struct miopen_apply
prog->replace_instruction(ins, prog->replace_instruction(ins,
miopen_convolution{op, std::move(cd)}, miopen_convolution{op, std::move(cd)},
handle,
ins->arguments.at(0), ins->arguments.at(0),
ins->arguments.at(1), ins->arguments.at(1),
output); output);
...@@ -261,7 +268,7 @@ struct miopen_apply ...@@ -261,7 +268,7 @@ struct miopen_apply
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
prog->replace_instruction( prog->replace_instruction(
ins, miopen_pooling{op, std::move(pd)}, handle, ins->arguments.at(0), output); ins, miopen_pooling{op, std::move(pd)}, ins->arguments.at(0), output);
} }
void apply_activation(instruction_ref ins) void apply_activation(instruction_ref ins)
...@@ -272,7 +279,7 @@ struct miopen_apply ...@@ -272,7 +279,7 @@ struct miopen_apply
{ {
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
prog->replace_instruction( prog->replace_instruction(
ins, miopen_relu{std::move(ad)}, handle, ins->arguments.at(0), output); ins, miopen_relu{std::move(ad)}, ins->arguments.at(0), output);
} }
} }
...@@ -280,7 +287,7 @@ struct miopen_apply ...@@ -280,7 +287,7 @@ struct miopen_apply
{ {
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
prog->replace_instruction( prog->replace_instruction(
ins, miopen_add{}, handle, ins->arguments.at(0), ins->arguments.at(1), output); ins, miopen_add{}, ins->arguments.at(0), ins->arguments.at(1), output);
} }
void apply_gemm(instruction_ref ins) void apply_gemm(instruction_ref ins)
...@@ -288,7 +295,7 @@ struct miopen_apply ...@@ -288,7 +295,7 @@ struct miopen_apply
auto&& op = any_cast<gemm>(ins->op); auto&& op = any_cast<gemm>(ins->op);
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
prog->replace_instruction( prog->replace_instruction(
ins, miopen_gemm{op}, handle, ins->arguments.at(0), ins->arguments.at(1), output); ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
} }
}; };
...@@ -296,6 +303,11 @@ std::string miopen_target::name() const { return "miopen"; } ...@@ -296,6 +303,11 @@ std::string miopen_target::name() const { return "miopen"; }
void miopen_target::apply(program& p) const { miopen_apply{&p}.apply(); } void miopen_target::apply(program& p) const { miopen_apply{&p}.apply(); }
context miopen_target::get_context() const
{
return miopen_context{share(make_obj<miopen_handle>(&miopenCreate))};
}
} // namespace miopen } // namespace miopen
} // namespace rtg } // namespace rtg
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
struct sum_op struct sum_op
{ {
std::string name() const { return "sum"; } std::string name() const { return "sum"; }
rtg::argument compute(rtg::shape, std::vector<rtg::argument> args) const rtg::argument compute(rtg::context&, rtg::shape, std::vector<rtg::argument> args) const
{ {
rtg::argument result; rtg::argument result;
if(args.size() != 2) if(args.size() != 2)
...@@ -37,7 +37,7 @@ struct sum_op ...@@ -37,7 +37,7 @@ struct sum_op
struct minus_op struct minus_op
{ {
std::string name() const { return "minus"; } std::string name() const { return "minus"; }
rtg::argument compute(rtg::shape, std::vector<rtg::argument> args) const rtg::argument compute(rtg::context&, rtg::shape, std::vector<rtg::argument> args) const
{ {
rtg::argument result; rtg::argument result;
if(args.size() != 2) if(args.size() != 2)
...@@ -67,6 +67,7 @@ struct id_target ...@@ -67,6 +67,7 @@ struct id_target
{ {
std::string name() const { return "id"; } std::string name() const { return "id"; }
void apply(rtg::program&) const {} void apply(rtg::program&) const {}
rtg::context get_context() const { return {}; }
}; };
void literal_test1() void literal_test1()
......
...@@ -36,8 +36,6 @@ rtg::argument run_gpu() ...@@ -36,8 +36,6 @@ rtg::argument run_gpu()
} }
m["output"] = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output"))); m["output"] = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output")));
auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate);
m["handle"] = {rtg::shape::any_type, handle.get()};
return rtg::miopen::from_gpu(p.eval(m)); return rtg::miopen::from_gpu(p.eval(m));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment