Commit 2211ade7 authored by Paul's avatar Paul
Browse files

Fix compile errors

parent dc21cca1
...@@ -35,6 +35,12 @@ struct argument : raw_data<argument> ...@@ -35,6 +35,12 @@ struct argument : raw_data<argument>
const shape& get_shape() const { return this->m_shape; } const shape& get_shape() const { return this->m_shape; }
template<class T>
T* cast() const
{
return reinterpret_cast<T*>(this->data());
}
private: private:
shape m_shape; shape m_shape;
}; };
......
...@@ -94,6 +94,27 @@ struct raw_data ...@@ -94,6 +94,27 @@ struct raw_data
this->visit_at([&](auto x) { result = x; }, n); this->visit_at([&](auto x) { result = x; }, n);
return result; return result;
} }
struct auto_cast
{
const Derived * self;
template<class T>
operator T()
{
return self->template at<T>();
}
template<class T>
operator T*()
{
// TODO: Check type
return reinterpret_cast<T*>(self->data());
}
};
auto_cast get() const
{
return {static_cast<const Derived*>(this)};
}
}; };
namespace detail { namespace detail {
......
...@@ -30,12 +30,14 @@ struct shape ...@@ -30,12 +30,14 @@ struct shape
#define RTG_SHAPE_ENUM_TYPES(x, t) x, #define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum type_t enum type_t
{ {
any_type,
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES) RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES)
}; };
#undef RTG_SHAPE_ENUM_TYPES #undef RTG_SHAPE_ENUM_TYPES
template <class T, class = void> template <class T, class = void>
struct get_type; struct get_type : std::integral_constant<type_t, any_type>
{};
#define RTG_SHAPE_GET_TYPE(x, t) \ #define RTG_SHAPE_GET_TYPE(x, t) \
template <class T> \ template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \ struct get_type<t, T> : std::integral_constant<type_t, x> \
...@@ -112,6 +114,7 @@ struct shape ...@@ -112,6 +114,7 @@ struct shape
{ {
switch(this->m_type) switch(this->m_type)
{ {
case any_type: RTG_THROW("Cannot visit the any_type");
#define RTG_SHAPE_VISITOR_CASE(x, t) \ #define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return; case x: v(as<t>()); return;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE) RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
......
...@@ -91,6 +91,7 @@ std::string shape::type_string() const ...@@ -91,6 +91,7 @@ std::string shape::type_string() const
{ {
switch(this->m_type) switch(this->m_type)
{ {
case any_type: return "any";
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \ #define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
case x: return #x; case x: return #x;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_TYPE_STRING_CASE) RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_TYPE_STRING_CASE)
......
...@@ -72,34 +72,35 @@ struct miopen_convolution ...@@ -72,34 +72,35 @@ struct miopen_convolution
auto w_desc = make_tensor(args[2].get_shape()); auto w_desc = make_tensor(args[2].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0;
int algo_count; int algo_count;
miopenConvAlgoPerf_t perf; miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(args[0].data(), miopenFindConvolutionForwardAlgorithm(args[0].get(),
x_desc.get(), x_desc.get(),
args[1].data(), args[1].get(),
w_desc, w_desc.get(),
args[2].data(), args[2].get(),
cd.get(), cd.get(),
y_desc, y_desc.get(),
args[4].data(), args[4].get(),
1, 1,
&algo_count, &algo_count,
&perf, &perf,
args[3].data(), args[3].get(),
args[3].get_shape().bytes(), args[3].get_shape().bytes(),
false); false);
miopenConvolutionForward(args[0].data(), miopenConvolutionForward(args[0].get(),
&alpha, &alpha,
x_desc, x_desc.get(),
args[1].data(), args[1].get(),
w_desc, w_desc.get(),
args[2].data(), args[2].get(),
cd.get(), cd.get(),
perf.fwd_algo, perf.fwd_algo,
&beta, &beta,
y_desc, y_desc.get(),
args[4].data(), args[4].get(),
args[3].data(), args[3].get(),
args[3].get_shape().bytes()); args[3].get_shape().bytes());
return result; return result;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment