Commit ecc1a605 authored by Paul's avatar Paul
Browse files

Add assert for cast

parent 17f5d17e
......@@ -39,12 +39,6 @@ struct argument : raw_data<argument>
const shape& get_shape() const { return this->m_shape; }
template <class T>
T* cast() const
{
return reinterpret_cast<T*>(this->data());
}
private:
shape m_shape;
};
......
......@@ -127,6 +127,17 @@ struct raw_data : raw_data_base
MIGRAPH_THROW("Incorrect data type for raw data");
return make_view(s, reinterpret_cast<T*>(buffer));
}
/// Cast the data pointer
template <class T>
T* cast() const
{
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
assert(s.type() == migraph::shape::get_type<T>{});
return reinterpret_cast<T*>(buffer);
}
};
template <class T,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment