Unverified Commit 5f1ea74f authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #9 from ROCmSoftwarePlatform/context

Add context object for execution
parents 7591b7ff a7dcd9fb
pfultz2/rocm-recipes
danmar/cppcheck@d9f9bdda7344e80585f71141be7797055d7987f3
# python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
-f requirements.txt
#ifndef RTG_GUARD_BUILTIN_HPP
#define RTG_GUARD_BUILTIN_HPP
#include <rtg/operation.hpp>
#include <rtg/context.hpp>
#include <rtg/errors.hpp>
namespace rtg {
......@@ -12,7 +12,7 @@ struct literal
{
std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("builtin"); }
};
struct outline
......@@ -20,7 +20,7 @@ struct outline
shape s;
std::string name() const { return "@outline"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("builtin"); }
};
struct param
......@@ -28,7 +28,7 @@ struct param
std::string parameter;
std::string name() const { return "@param"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const param& op)
{
os << op.name() << ":" << op.parameter;
......
#ifndef RTG_GUARD_CONTEXT_HPP
#define RTG_GUARD_CONTEXT_HPP
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <type_traits>
#include <typeinfo>
#include <utility>
namespace rtg {
/*
* Type-erased interface for:
*
* struct context
* {
* };
*
*/
/// Type-erased holder for a target-specific execution context.
///
/// Any copyable value can be stored. Copies of a `context` share one handle;
/// mutable access through `any_cast()` clones the handle first when it is
/// shared, so writes through one copy are not visible through another.
/// A default-constructed context is empty and reports
/// `typeid(std::nullptr_t)` from `type_id()`.
struct context
{
    // Constructors
    context() = default;

    /// Stores `value` inside a freshly allocated handle.
    template <typename PrivateDetailTypeErasedT>
    context(PrivateDetailTypeErasedT value)
        : private_detail_te_handle_mem_var(
              std::make_shared<private_detail_te_handle_type<
                  typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
                  std::forward<PrivateDetailTypeErasedT>(value)))
    {
    }

    // Assignment
    /// Replaces the stored value with `value` (of any type).
    ///
    /// A new handle is always built and swapped in: the handle base class has
    /// no way to assign a value in place, and other copies that still share
    /// the old handle must keep seeing the old value.
    template <typename PrivateDetailTypeErasedT>
    context& operator=(PrivateDetailTypeErasedT value)
    {
        context replacement(std::forward<PrivateDetailTypeErasedT>(value));
        private_detail_te_handle_mem_var.swap(replacement.private_detail_te_handle_mem_var);
        return *this;
    }

    // Cast
    /// Returns a pointer to the stored value when it has type
    /// `PrivateDetailTypeErasedT`; returns nullptr when the context is empty
    /// or holds another type.
    template <typename PrivateDetailTypeErasedT>
    PrivateDetailTypeErasedT* any_cast()
    {
        // Check the type through the const path first so that a mismatch (or
        // an empty context) neither asserts nor clones a shared handle.
        if(!private_detail_te_holds<PrivateDetailTypeErasedT>())
            return nullptr;
        return std::addressof(
            static_cast<private_detail_te_handle_type<
                typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                private_detail_te_get_handle())
                .private_detail_te_value);
    }

    /// Const overload of any_cast(); never clones the handle.
    template <typename PrivateDetailTypeErasedT>
    const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
    {
        if(!private_detail_te_holds<PrivateDetailTypeErasedT>())
            return nullptr;
        return std::addressof(
            static_cast<const private_detail_te_handle_type<
                typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                private_detail_te_get_handle())
                .private_detail_te_value);
    }

    /// Type of the stored value, or `typeid(std::nullptr_t)` when empty.
    const std::type_info& type_id() const
    {
        if(private_detail_te_handle_empty())
            return typeid(std::nullptr_t);
        else
            return private_detail_te_get_handle().type();
    }

    private:
    // Abstract interface every stored value is wrapped in.
    struct private_detail_te_handle_base_type
    {
        virtual ~private_detail_te_handle_base_type() = default;
        virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
        virtual const std::type_info& type() const                                = 0;
    };

    // Concrete wrapper holding a value (or a reference) of one specific type.
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type : private_detail_te_handle_base_type
    {
        // Selected when the wrapped type is a reference: bind, do not move.
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
                nullptr)
            : private_detail_te_value(value)
        {
        }

        // Selected when the wrapped type is a value type: move it into place.
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
                                    int>::type* = nullptr) noexcept
            : private_detail_te_value(std::move(value))
        {
        }

        std::shared_ptr<private_detail_te_handle_base_type> clone() const override
        {
            return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
        }

        const std::type_info& type() const override { return typeid(private_detail_te_value); }

        PrivateDetailTypeErasedT private_detail_te_value;
    };

    // std::reference_wrapper<T> is stored as a handle over T&.
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
        : private_detail_te_handle_type<PrivateDetailTypeErasedT&>
    {
        private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
            : private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
        {
        }
    };

    // True when a value is stored and its type matches T (cv-qualifiers are
    // ignored by typeid). Safe to call on an empty context.
    template <typename PrivateDetailTypeErasedT>
    bool private_detail_te_holds() const
    {
        return private_detail_te_handle_mem_var != nullptr &&
               private_detail_te_handle_mem_var->type() == typeid(PrivateDetailTypeErasedT);
    }

    bool private_detail_te_handle_empty() const
    {
        return private_detail_te_handle_mem_var == nullptr;
    }

    const private_detail_te_handle_base_type& private_detail_te_get_handle() const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        return *private_detail_te_handle_mem_var;
    }

    private_detail_te_handle_base_type& private_detail_te_get_handle()
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        // Detach from other copies before handing out mutable access.
        // use_count() replaces shared_ptr::unique(), which C++20 removed.
        if(private_detail_te_handle_mem_var.use_count() > 1)
            private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
        return *private_detail_te_handle_mem_var;
    }

    std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
/// Pointer form of any_cast for a const context.
///
/// Returns a pointer to the stored value when `x` holds a `ValueType`,
/// otherwise nullptr. A null `x` yields nullptr instead of being
/// dereferenced, matching the pointer overloads of std::any_cast.
template <typename ValueType>
inline const ValueType* any_cast(const context* x)
{
    return x == nullptr ? nullptr : x->any_cast<ValueType>();
}
/// Pointer form of any_cast for a mutable context.
///
/// Returns a pointer to the stored value when `x` holds a `ValueType`,
/// otherwise nullptr. A null `x` yields nullptr instead of being
/// dereferenced, matching the pointer overloads of std::any_cast.
template <typename ValueType>
inline ValueType* any_cast(context* x)
{
    return x == nullptr ? nullptr : x->any_cast<ValueType>();
}
/// Reference form of any_cast for a mutable context.
///
/// Returns a reference to the stored value; throws std::bad_cast when the
/// member any_cast reports no value of the requested type.
template <typename ValueType>
inline ValueType& any_cast(context& x)
{
    using stored_type = typename std::remove_reference<ValueType>::type;
    if(auto* found = x.any_cast<stored_type>())
        return *found;
    throw std::bad_cast();
}
template <typename ValueType>
inline const ValueType& any_cast(const context& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
} // namespace rtg
#endif
......@@ -8,6 +8,7 @@
#include <utility>
#include <rtg/shape.hpp>
#include <rtg/argument.hpp>
#include <rtg/context.hpp>
namespace rtg {
......@@ -28,7 +29,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(shape output,std::vector<argument> input) const;
* argument compute(context& ctx,shape output,std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
*
......@@ -83,6 +84,14 @@ struct operation
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -95,10 +104,11 @@ struct operation
return (*this).private_detail_te_get_handle().compute_shape(std::move(input));
}
argument compute(shape output, std::vector<argument> input) const
argument compute(context& ctx, shape output, std::vector<argument> input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(std::move(output), std::move(input));
return (*this).private_detail_te_get_handle().compute(
ctx, std::move(output), std::move(input));
}
friend std::ostream& operator<<(std::ostream& os, const operation& op)
......@@ -116,7 +126,7 @@ struct operation
virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(shape output, std::vector<argument> input) const = 0;
virtual argument compute(context& ctx, shape output, std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
};
......@@ -156,10 +166,10 @@ struct operation
return private_detail_te_value.compute_shape(std::move(input));
}
argument compute(shape output, std::vector<argument> input) const override
argument compute(context& ctx, shape output, std::vector<argument> input) const override
{
return private_detail_te_value.compute(std::move(output), std::move(input));
return private_detail_te_value.compute(ctx, std::move(output), std::move(input));
}
std::ostream& operator_shift_left(std::ostream& os) const override
......@@ -181,13 +191,20 @@ struct operation
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
......
......@@ -12,15 +12,29 @@ namespace rtg {
struct check_shapes
{
const std::vector<shape>* shapes;
const std::string name;
check_shapes(const std::vector<shape>& s) : shapes(&s) {}
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name())
{
}
std::string prefix() const
{
if(name.empty())
return "";
else
return name + ": ";
}
const check_shapes& has(std::size_t n) const
{
assert(shapes != nullptr);
if(shapes->size() != n)
RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " +
std::to_string(shapes->size()));
RTG_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
" but given " + std::to_string(shapes->size()));
return *this;
}
......@@ -30,7 +44,7 @@ struct check_shapes
if(!shapes->empty())
{
if(shapes->front().lens().size() != n)
RTG_THROW("Only " + std::to_string(n) + "d supported");
RTG_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
}
return *this;
}
......@@ -38,28 +52,28 @@ struct check_shapes
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
RTG_THROW("Shapes do not match");
RTG_THROW(prefix() + "Shapes do not match");
return *this;
}
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
RTG_THROW("Types do not match");
RTG_THROW(prefix() + "Types do not match");
return *this;
}
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
RTG_THROW("Dimensions do not match");
RTG_THROW(prefix() + "Dimensions do not match");
return *this;
}
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.lens().size(); }))
RTG_THROW("Dimensions do not match");
RTG_THROW(prefix() + "Dimensions do not match");
return *this;
}
......@@ -83,7 +97,7 @@ struct check_shapes
struct not_computable
{
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
};
struct convolution
......@@ -101,7 +115,7 @@ struct convolution
std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_ndims().only_dims(4);
check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
......@@ -153,7 +167,7 @@ struct convolution
}
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{
......@@ -175,7 +189,7 @@ struct pooling
std::string name() const { return "pooling"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1).only_dims(4);
check_shapes{inputs, *this}.has(1).only_dims(4);
const shape& input = inputs.at(0);
auto t = input.type();
......@@ -200,7 +214,7 @@ struct pooling
}};
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{
......@@ -219,11 +233,11 @@ struct activation
std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const activation& op)
{
os << op.name() << ":" << op.mode;
......@@ -237,7 +251,7 @@ struct transpose
std::string name() const { return "transpose"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0);
auto input_lens = input.lens();
auto input_strides = input.strides();
......@@ -261,7 +275,7 @@ struct transpose
}
return {t, output_lens, output_strides};
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
};
struct contiguous
......@@ -269,7 +283,7 @@ struct contiguous
std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
if(lens.size() < 2)
......@@ -278,7 +292,7 @@ struct contiguous
}
return {t, lens};
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
};
struct reshape
......@@ -287,7 +301,7 @@ struct reshape
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++)
......@@ -306,7 +320,7 @@ struct reshape
return s;
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{
......@@ -322,7 +336,7 @@ struct gemm
std::string name() const { return "gemm"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type();
check_shapes{inputs, *this}.has(2).same_type();
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
......@@ -332,7 +346,7 @@ struct gemm
return {t, {a.lens()[0], b.lens()[1]}};
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const gemm& op)
{
......@@ -349,7 +363,7 @@ struct unary
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
};
struct identity : unary
......@@ -449,7 +463,7 @@ struct broadcast
return {t, result.lens(), std::move(bcast_strides)};
}
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.at(1).data)};
}
......@@ -463,7 +477,7 @@ struct binary
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
};
struct add : binary
......@@ -492,10 +506,24 @@ struct outline
std::string name() const { return "outline"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(0);
check_shapes{inputs, *this}.has(0);
return s;
}
argument compute(shape, std::vector<argument>) const { return {s, nullptr}; }
argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; }
};
/// Operator that verifies the execution context holds a value of type `T`.
///
/// `compute` throws (via RTG_THROW) with the actual context type name when
/// the context holds anything else, and otherwise returns an empty argument.
template <class T>
struct check_context
{
    std::string name() const { return "check_context"; }
    shape compute_shape(std::vector<shape>) const { return {}; }
    argument compute(context& ctx, shape, std::vector<argument>) const
    {
        // Compare via type_id() rather than any_cast: type_id() is defined
        // for an empty context (it reports std::nullptr_t), so a target that
        // supplies no context produces the error below instead of tripping
        // the handle assertion inside context::any_cast.
        if(ctx.type_id() != typeid(T))
            RTG_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
        return {};
    }
};
} // namespace rtg
......
......@@ -89,13 +89,27 @@ struct raw_data : raw_data_base
assert(self->single());
return self->template at<T>();
}
template <class T>
using is_data_ptr =
bool_c<(std::is_void<T>{} or std::is_same<char, std::remove_cv_t<T>>{} or
std::is_same<unsigned char, std::remove_cv_t<T>>{})>;
template <class T>
using get_data_type = std::conditional_t<is_data_ptr<T>{}, float, T>;
template <class T>
bool matches() const
{
return is_data_ptr<T>{} ||
self->get_shape().type() == rtg::shape::get_type<get_data_type<T>>{};
}
template <class T>
operator T*()
{
using type = std::remove_cv_t<T>;
assert((std::is_void<T>{} or std::is_same<char, type>{} or
std::is_same<unsigned char, type>{} or
self->get_shape().type() == rtg::shape::get_type<T>{}));
assert(matches<T>());
return reinterpret_cast<type*>(self->data());
}
};
......
......@@ -10,10 +10,15 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
{
};
template <bool B>
using bool_c = std::integral_constant<bool, B>;
#ifdef CPPCHECK
#define RTG_REQUIRES(...) class = void
#else
#define RTG_REQUIRES(...) class = typename std::enable_if<and_<__VA_ARGS__, true>{}>::type
#define RTG_REQUIRES(...) \
bool PrivateRequires##__LINE__ = true, \
class = typename std::enable_if<and_<__VA_ARGS__, PrivateRequires##__LINE__>{}>::type
#endif
} // namespace rtg
......
......@@ -31,15 +31,12 @@ struct shape
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum type_t
{
any_type,
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES)
};
#undef RTG_SHAPE_ENUM_TYPES
template <class T, class = void>
struct get_type : std::integral_constant<type_t, any_type>
{
};
struct get_type;
#define RTG_SHAPE_GET_TYPE(x, t) \
template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
......@@ -125,7 +122,6 @@ struct shape
{
switch(this->m_type)
{
case any_type: RTG_THROW("Cannot visit the any_type");
#define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
......
......@@ -6,6 +6,7 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <rtg/context.hpp>
namespace rtg {
......@@ -18,6 +19,7 @@ struct program;
* {
* std::string name() const;
* void apply(program & p) const;
* context get_context() const;
* };
*
*/
......@@ -71,6 +73,14 @@ struct target
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -83,6 +93,12 @@ struct target
return (*this).private_detail_te_get_handle().apply(p);
}
context get_context() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_context();
}
private:
struct private_detail_te_handle_base_type
{
......@@ -92,6 +108,7 @@ struct target
virtual std::string name() const = 0;
virtual void apply(program& p) const = 0;
virtual context get_context() const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -126,6 +143,8 @@ struct target
void apply(program& p) const override { return private_detail_te_value.apply(p); }
context get_context() const override { return private_detail_te_value.get_context(); }
PrivateDetailTypeErasedT private_detail_te_value;
};
......@@ -139,13 +158,20 @@ struct target
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
......
......@@ -25,7 +25,7 @@ struct unknown
else
return input.front();
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
......
......@@ -29,8 +29,7 @@ rtg::argument run_gpu(std::string file)
auto output = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output")));
auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate);
auto out = p.eval(
{{"Input3", input3}, {"handle", {rtg::shape::any_type, handle.get()}}, {"output", output}});
auto out = p.eval({{"Input3", input3}, {"output", output}});
std::cout << p << std::endl;
return rtg::miopen::from_gpu(out);
}
......
......@@ -10,6 +10,7 @@ struct program_impl
{
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
context ctx;
};
const operation& get_operation(instruction_ref ins) { return ins->op; }
......@@ -109,6 +110,7 @@ instruction_ref program::validate() const
void program::compile(const target& t)
{
assert(this->validate() != impl->instructions.end());
this->impl->ctx = t.get_context();
t.apply(*this);
if(this->validate() == impl->instructions.end())
RTG_THROW("Invalid program from compilation");
......@@ -140,7 +142,7 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
ins.arguments.end(),
values.begin(),
[&](instruction_ref i) { return results.at(std::addressof(*i)); });
result = ins.op.compute(ins.result, values);
result = ins.op.compute(this->impl->ctx, ins.result, values);
}
results.emplace(std::addressof(ins), result);
}
......
......@@ -109,7 +109,6 @@ std::string shape::type_string() const
{
switch(this->m_type)
{
case any_type: return "any";
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
case x: return #x;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_TYPE_STRING_CASE)
......
......@@ -20,7 +20,7 @@ struct cpu_convolution
std::string name() const { return "cpu::convolution"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
......@@ -86,7 +86,7 @@ struct cpu_pooling
std::string name() const { return "cpu::pooling_" + Op::name(); }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
......@@ -134,7 +134,7 @@ struct cpu_transpose
std::string name() const { return "cpu::transpose"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
......@@ -145,7 +145,7 @@ struct cpu_contiguous
contiguous op;
std::string name() const { return "cpu::contiguous"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
......@@ -163,7 +163,7 @@ struct cpu_reshape
std::string name() const { return "cpu::reshape"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
......@@ -175,7 +175,7 @@ struct cpu_gemm
std::string name() const { return "cpu::gemm"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto cmat, auto amat, auto bmat) {
......@@ -334,7 +334,7 @@ struct cpu_unary
Op op;
std::string name() const { return op.name(); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto output) {
......@@ -350,7 +350,7 @@ struct softmax2d
{
std::string name() const { return "cpu::softmax2d"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
......@@ -426,7 +426,7 @@ struct cpu_binary
Op op;
std::string name() const { return op.name(); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
......
......@@ -10,6 +10,7 @@ struct cpu_target
{
std::string name() const;
void apply(program& p) const;
context get_context() const { return {}; }
};
} // namespace cpu
......
......@@ -20,7 +20,7 @@ struct hip_allocate
check_shapes{inputs}.has(1);
return inputs.front();
}
argument compute(shape output_shape, std::vector<argument>) const
argument compute(context&, shape output_shape, std::vector<argument>) const
{
return allocate_gpu(output_shape);
}
......
......@@ -10,6 +10,7 @@ struct miopen_target
{
std::string name() const;
void apply(program& p) const;
context get_context() const;
};
} // namespace miopen
......
......@@ -10,6 +10,11 @@
namespace rtg {
namespace miopen {
// Execution context for the MIOpen target. The miopen_* operators obtain it
// with any_cast<miopen_context>(ctx) inside compute() and pass handle.get()
// as the first argument of the MIOpen calls they make.
struct miopen_context
{
// MIOpen handle used for every MIOpen call issued by the operators.
shared<miopen_handle> handle;
};
struct miopen_convolution
{
convolution op;
......@@ -18,46 +23,47 @@ struct miopen_convolution
std::string name() const { return "miopen::convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(4);
return op.compute_shape({inputs.at(1), inputs.at(2)});
check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{
auto x_desc = make_tensor(args[1].get_shape());
auto w_desc = make_tensor(args[2].get_shape());
auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape());
auto w_desc = make_tensor(args[1].get_shape());
auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0;
int algo_count;
miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(args[0].implicit(),
miopenFindConvolutionForwardAlgorithm(ctx.handle.get(),
x_desc.get(),
args[1].implicit(),
args[0].implicit(),
w_desc.get(),
args[2].implicit(),
args[1].implicit(),
cd.get(),
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
1,
&algo_count,
&perf,
nullptr,
0,
false);
miopenConvolutionForward(args[0].implicit(),
miopenConvolutionForward(ctx.handle.get(),
&alpha,
x_desc.get(),
args[1].implicit(),
args[0].implicit(),
w_desc.get(),
args[2].implicit(),
args[1].implicit(),
cd.get(),
perf.fwd_algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
nullptr,
0);
return args[3];
return args[2];
}
};
......@@ -69,29 +75,30 @@ struct miopen_pooling
std::string name() const { return "miopen::pooling"; }
// Computes the pooled output shape.
// Inputs are {x, output buffer}: the MIOpen handle is no longer an argument,
// so the tensor to pool is inputs[0]. Using inputs[1] would apply the pooling
// shape computation to the output buffer's shape instead of the input's.
shape compute_shape(std::vector<shape> inputs) const
{
    check_shapes{inputs, *this}.has(2);
    return op.compute_shape({inputs.at(0)});
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{
auto x_desc = make_tensor(args[1].get_shape());
auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0;
miopenPoolingForward(args[0].implicit(),
miopenPoolingForward(ctx.handle.get(),
pd.get(),
&alpha,
x_desc.get(),
args[1].implicit(),
args[0].implicit(),
&beta,
y_desc.get(),
args[2].implicit(),
args[1].implicit(),
false,
nullptr,
0);
return args[2];
return args[1];
}
};
......@@ -100,17 +107,17 @@ struct miopen_add
std::string name() const { return "miopen::add"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(4);
return inputs.at(1);
check_shapes{inputs, *this}.has(3);
return inputs.at(0);
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{
if(args[2].get_shape().broadcasted())
if(args[1].get_shape().broadcasted())
{
argument result{output_shape};
visit_all(result, from_gpu(args[1]), from_gpu(args[2]))(
visit_all(result, from_gpu(args[0]), from_gpu(args[1]))(
[&](auto output, auto input1, auto input2) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) =
......@@ -121,22 +128,23 @@ struct miopen_add
}
else
{
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0;
auto a_desc = make_tensor(args[1].get_shape());
auto b_desc = make_tensor(args[2].get_shape());
auto a_desc = make_tensor(args[0].get_shape());
auto b_desc = make_tensor(args[1].get_shape());
auto c_desc = make_tensor(output_shape);
miopenOpTensor(args[0].implicit(),
miopenOpTensor(ctx.handle.get(),
miopenTensorOpAdd,
&alpha,
a_desc.get(),
args[1].implicit(),
args[0].implicit(),
&alpha,
b_desc.get(),
args[2].implicit(),
args[1].implicit(),
&beta,
c_desc.get(),
args[3].implicit());
return args[3];
args[2].implicit());
return args[2];
}
}
};
......@@ -147,14 +155,14 @@ struct miopen_gemm
// Identifier under which this operator is reported. The enclosing struct is
// miopen_gemm (instantiated by apply_gemm), so it reports "miopen::gemm";
// the previous "miopen::convolution" mislabelled GEMM instructions in printed
// output and in shape-check diagnostics.
std::string name() const { return "miopen::gemm"; }
// Validates the number of input shapes and delegates to the wrapped op's
// compute_shape with the shapes of the two operands.
// NOTE(review): this span is a rendered diff; the statements appear in pairs,
// the old form (4 inputs, operands at indices 1 and 2) followed by the new form
// (3 inputs, operands at indices 0 and 1).
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(4);
return op.compute_shape({inputs.at(1), inputs.at(2)});
check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, from_gpu(args[1]), from_gpu(args[2]))(
visit_all(result, from_gpu(args[0]), from_gpu(args[1]))(
[&](auto output, auto input1, auto input2) {
dfor(input1.get_shape().lens()[0],
input2.get_shape().lens()[1],
......@@ -171,36 +179,36 @@ struct miopen_relu
// Identifier under which the MIOpen ReLU operator is reported.
std::string name() const
{
    return "miopen::relu";
}
// Validates the number of input shapes and returns the shape at index 1.
// NOTE(review): this span is a rendered diff; the two check_shapes lines appear
// to be the old (3 inputs) and new (2 inputs) forms of the same statement.
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(3);
check_shapes{inputs, *this}.has(2);
// NOTE(review): with two inputs, index 1 is the output allocation (the new
// compute() returns args[1]); presumably it has the same shape as the input at
// index 0, but inputs.at(0) may be what is intended -- confirm.
return inputs.at(1);
}
// Calls miopenActivationForward with tensor descriptors built from the input
// argument's shape and from output_shape, then returns the output argument.
// NOTE(review): this span is a rendered diff; where two near-identical lines
// follow each other, the first appears to be the old form (handle passed as
// args[0]) and the second the new form (handle taken from the context, every
// args index shifted down by one).
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{
// The generic context is cast to the MIOpen-specific context to reach its handle.
auto& ctx = any_cast<miopen_context>(gctx);
// alpha and beta are passed by address to the MIOpen call below.
float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[1].get_shape());
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
// NOTE(review): the value returned by miopenActivationForward is discarded here.
miopenActivationForward(args[0].implicit(),
miopenActivationForward(ctx.handle.get(),
ad.get(),
&alpha,
x_desc.get(),
args[1].implicit(),
args[0].implicit(),
&beta,
y_desc.get(),
args[2].implicit());
args[1].implicit());
// The last argument is the output buffer that was passed to the MIOpen call.
return args[2];
return args[1];
}
};
struct miopen_apply
{
program* prog = nullptr;
instruction_ref handle{};
void apply()
{
handle = prog->add_parameter("handle", shape{shape::any_type});
prog->insert_instruction(prog->begin(), check_context<miopen_context>{});
for(auto it = prog->begin(); it != prog->end(); it++)
{
if(it->op.name() == "convolution")
......@@ -248,7 +256,6 @@ struct miopen_apply
prog->replace_instruction(ins,
miopen_convolution{op, std::move(cd)},
handle,
ins->arguments.at(0),
ins->arguments.at(1),
output);
......@@ -261,7 +268,7 @@ struct miopen_apply
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(
ins, miopen_pooling{op, std::move(pd)}, handle, ins->arguments.at(0), output);
ins, miopen_pooling{op, std::move(pd)}, ins->arguments.at(0), output);
}
void apply_activation(instruction_ref ins)
......@@ -272,7 +279,7 @@ struct miopen_apply
{
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(
ins, miopen_relu{std::move(ad)}, handle, ins->arguments.at(0), output);
ins, miopen_relu{std::move(ad)}, ins->arguments.at(0), output);
}
}
......@@ -280,7 +287,7 @@ struct miopen_apply
{
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(
ins, miopen_add{}, handle, ins->arguments.at(0), ins->arguments.at(1), output);
ins, miopen_add{}, ins->arguments.at(0), ins->arguments.at(1), output);
}
void apply_gemm(instruction_ref ins)
......@@ -288,7 +295,7 @@ struct miopen_apply
auto&& op = any_cast<gemm>(ins->op);
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(
ins, miopen_gemm{op}, handle, ins->arguments.at(0), ins->arguments.at(1), output);
ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
}
};
......@@ -296,6 +303,11 @@ std::string miopen_target::name() const { return "miopen"; }
// Runs the miopen_apply pass over the given program, modifying it in place.
void miopen_target::apply(program& p) const
{
    miopen_apply lowering{&p};
    lowering.apply();
}
// Builds the execution context for this target: a miopen_handle object created
// through miopenCreate, passed through share(), and wrapped in a miopen_context.
context miopen_target::get_context() const
{
    auto handle = share(make_obj<miopen_handle>(&miopenCreate));
    return miopen_context{handle};
}
} // namespace miopen
} // namespace rtg
......@@ -8,7 +8,7 @@
struct sum_op
{
std::string name() const { return "sum"; }
rtg::argument compute(rtg::shape, std::vector<rtg::argument> args) const
rtg::argument compute(rtg::context&, rtg::shape, std::vector<rtg::argument> args) const
{
rtg::argument result;
if(args.size() != 2)
......@@ -37,7 +37,7 @@ struct sum_op
struct minus_op
{
std::string name() const { return "minus"; }
rtg::argument compute(rtg::shape, std::vector<rtg::argument> args) const
rtg::argument compute(rtg::context&, rtg::shape, std::vector<rtg::argument> args) const
{
rtg::argument result;
if(args.size() != 2)
......@@ -67,6 +67,7 @@ struct id_target
{
// Identifier of this test target.
std::string name() const
{
    return "id";
}
// Leaves the program untouched; this target performs no transformation.
void apply(rtg::program&) const
{
}
// Returns a default-constructed (empty) context.
rtg::context get_context() const
{
    return rtg::context{};
}
};
void literal_test1()
......
......@@ -36,8 +36,6 @@ rtg::argument run_gpu()
}
m["output"] = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output")));
auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate);
m["handle"] = {rtg::shape::any_type, handle.get()};
return rtg::miopen::from_gpu(p.eval(m));
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment