"...composable_kernel_rocm.git" did not exist on "09f3a75ee8ada2a6d5d7c7f4bd32fffbba101438"
Commit 1b18e9d0 authored by Paul's avatar Paul
Browse files

Return an argument instead of a literal from eval

parent 8ff0905c
...@@ -62,7 +62,7 @@ struct program ...@@ -62,7 +62,7 @@ struct program
shape get_parameter_shape(std::string name); shape get_parameter_shape(std::string name);
literal eval(std::unordered_map<std::string, argument> params) const; argument eval(std::unordered_map<std::string, argument> params) const;
friend std::ostream& operator<<(std::ostream& os, const program& p); friend std::ostream& operator<<(std::ostream& os, const program& p);
......
...@@ -6,6 +6,11 @@ ...@@ -6,6 +6,11 @@
namespace rtg { namespace rtg {
#define RTG_REQUIRES(...) class=typename std::enable_if<(__VA_ARGS__)>::type
struct raw_data_base
{};
/** /**
* @brief Provides a base class for common operations with raw buffer * @brief Provides a base class for common operations with raw buffer
* *
...@@ -15,29 +20,8 @@ namespace rtg { ...@@ -15,29 +20,8 @@ namespace rtg {
* *
*/ */
template <class Derived> template <class Derived>
struct raw_data struct raw_data : raw_data_base
{ {
friend bool operator==(const Derived& x, const Derived& y)
{
auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape();
bool result = x.empty() && y.empty();
if(not result && xshape == yshape)
{
auto&& xbuffer = x.data();
auto&& ybuffer = y.data();
// TODO: Dont use tensor view for single values
xshape.visit_type([&](auto as) {
auto xview = make_view(xshape, as.from(xbuffer));
auto yview = make_view(yshape, as.from(ybuffer));
result = xview == yview;
});
}
return result;
}
friend bool operator!=(const Derived& x, const Derived& y) { return !(x == y); }
template <class Stream> template <class Stream>
friend Stream& operator<<(Stream& os, const Derived& d) friend Stream& operator<<(Stream& os, const Derived& d)
{ {
...@@ -114,6 +98,32 @@ struct raw_data ...@@ -114,6 +98,32 @@ struct raw_data
auto_cast get() const { return {static_cast<const Derived*>(this)}; } auto_cast get() const { return {static_cast<const Derived*>(this)}; }
}; };
template<class T, class U, RTG_REQUIRES(std::is_base_of<raw_data_base, T>{}), RTG_REQUIRES(std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape();
bool result = x.empty() && y.empty();
if(not result && xshape == yshape)
{
auto&& xbuffer = x.data();
auto&& ybuffer = y.data();
// TODO: Dont use tensor view for single values
xshape.visit_type([&](auto as) {
auto xview = make_view(xshape, as.from(xbuffer));
auto yview = make_view(yshape, as.from(ybuffer));
result = xview == yview;
});
}
return result;
}
template<class T, class U, RTG_REQUIRES(std::is_base_of<raw_data_base, T>{}), RTG_REQUIRES(std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{
return !(x == y);
}
namespace detail { namespace detail {
template <class V, class... Ts> template <class V, class... Ts>
void visit_all_impl(const shape& s, V&& v, Ts&&... xs) void visit_all_impl(const shape& s, V&& v, Ts&&... xs)
......
...@@ -103,22 +103,6 @@ struct tensor_view ...@@ -103,22 +103,6 @@ struct tensor_view
return m_data + this->size(); return m_data + this->size();
} }
friend bool operator==(const tensor_view<T>& x, const tensor_view<T>& y)
{
if(x.m_shape == y.m_shape)
{
for(std::size_t i = 0; i < x.m_shape.elements(); i++)
{
if(!float_equal(x[i], y[i]))
return false;
}
return true;
}
return false;
}
friend bool operator!=(const tensor_view<T>& x, const tensor_view<T>& y) { return !(x == y); }
friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x) friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
{ {
if(!x.empty()) if(!x.empty())
...@@ -137,6 +121,24 @@ struct tensor_view ...@@ -137,6 +121,24 @@ struct tensor_view
shape m_shape; shape m_shape;
}; };
template<class T, class U>
bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
{
if(x.get_shape() == y.get_shape())
{
for(std::size_t i = 0; i < x.get_shape().elements(); i++)
{
if(!float_equal(x[i], y[i]))
return false;
}
return true;
}
return false;
}
template<class T, class U>
bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y) { return !(x == y); }
template <class T> template <class T>
tensor_view<T> make_view(shape s, T* data) tensor_view<T> make_view(shape s, T* data)
{ {
......
...@@ -112,7 +112,7 @@ void program::compile(const target& t) ...@@ -112,7 +112,7 @@ void program::compile(const target& t)
RTG_THROW("Invalid program from compilation"); RTG_THROW("Invalid program from compilation");
} }
literal program::eval(std::unordered_map<std::string, argument> params) const argument program::eval(std::unordered_map<std::string, argument> params) const
{ {
assert(this->validate() != impl->instructions.end()); assert(this->validate() != impl->instructions.end());
std::unordered_map<const instruction*, argument> results; std::unordered_map<const instruction*, argument> results;
...@@ -142,7 +142,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const ...@@ -142,7 +142,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
} }
results.emplace(std::addressof(ins), result); results.emplace(std::addressof(ins), result);
} }
return literal{result.get_shape(), result.data()}; return result;
} }
std::ostream& operator<<(std::ostream& os, const program& p) std::ostream& operator<<(std::ostream& os, const program& p)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment