Commit cff16121 authored by Paul's avatar Paul
Browse files

Merge branch 'cast-check'

parents 165d41f3 e5d5cc2a
...@@ -39,12 +39,6 @@ struct argument : raw_data<argument> ...@@ -39,12 +39,6 @@ struct argument : raw_data<argument>
const shape& get_shape() const { return this->m_shape; } const shape& get_shape() const { return this->m_shape; }
template <class T>
T* cast() const
{
return reinterpret_cast<T*>(this->data());
}
private: private:
shape m_shape; shape m_shape;
}; };
......
...@@ -127,6 +127,16 @@ struct raw_data : raw_data_base ...@@ -127,6 +127,16 @@ struct raw_data : raw_data_base
MIGRAPH_THROW("Incorrect data type for raw data"); MIGRAPH_THROW("Incorrect data type for raw data");
return make_view(s, reinterpret_cast<T*>(buffer)); return make_view(s, reinterpret_cast<T*>(buffer));
} }
/// Cast the data pointer
template <class T>
T* cast() const
{
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
assert(s.type() == migraph::shape::get_type<T>{});
return reinterpret_cast<T*>(buffer);
}
}; };
template <class T, template <class T,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment