Commit 2211ade7 authored by Paul's avatar Paul
Browse files

Fix compile errors

parent dc21cca1
......@@ -35,6 +35,12 @@ struct argument : raw_data<argument>
const shape& get_shape() const { return this->m_shape; }
template<class T>
T* cast() const
{
return reinterpret_cast<T*>(this->data());
}
private:
shape m_shape;
};
......
......@@ -94,6 +94,27 @@ struct raw_data
this->visit_at([&](auto x) { result = x; }, n);
return result;
}
struct auto_cast
{
const Derived * self;
template<class T>
operator T()
{
return self->template at<T>();
}
template<class T>
operator T*()
{
// TODO: Check type
return reinterpret_cast<T*>(self->data());
}
};
auto_cast get() const
{
return {static_cast<const Derived*>(this)};
}
};
namespace detail {
......
......@@ -30,12 +30,14 @@ struct shape
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum type_t
{
any_type,
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES)
};
#undef RTG_SHAPE_ENUM_TYPES
template <class T, class = void>
struct get_type;
struct get_type : std::integral_constant<type_t, any_type>
{};
#define RTG_SHAPE_GET_TYPE(x, t) \
template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
......@@ -112,6 +114,7 @@ struct shape
{
switch(this->m_type)
{
case any_type: RTG_THROW("Cannot visit the any_type");
#define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
......
......@@ -91,6 +91,7 @@ std::string shape::type_string() const
{
switch(this->m_type)
{
case any_type: return "any";
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
case x: return #x;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_TYPE_STRING_CASE)
......
......@@ -72,34 +72,35 @@ struct miopen_convolution
auto w_desc = make_tensor(args[2].get_shape());
auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0;
int algo_count;
miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(args[0].data(),
miopenFindConvolutionForwardAlgorithm(args[0].get(),
x_desc.get(),
args[1].data(),
w_desc,
args[2].data(),
args[1].get(),
w_desc.get(),
args[2].get(),
cd.get(),
y_desc,
args[4].data(),
y_desc.get(),
args[4].get(),
1,
&algo_count,
&perf,
args[3].data(),
args[3].get(),
args[3].get_shape().bytes(),
false);
miopenConvolutionForward(args[0].data(),
miopenConvolutionForward(args[0].get(),
&alpha,
x_desc,
args[1].data(),
w_desc,
args[2].data(),
x_desc.get(),
args[1].get(),
w_desc.get(),
args[2].get(),
cd.get(),
perf.fwd_algo,
&beta,
y_desc,
args[4].data(),
args[3].data(),
y_desc.get(),
args[4].get(),
args[3].get(),
args[3].get_shape().bytes());
return result;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment