"...git@developer.sourcefind.cn:modelzoo/qwen_lmdeploy.git" did not exist on "6c7d99928251e03249ac2c65006c7452f5676bb7"
Commit 1b18e9d0 authored by Paul's avatar Paul
Browse files

Return an argument instead of a literal from eval

parent 8ff0905c
......@@ -62,7 +62,7 @@ struct program
shape get_parameter_shape(std::string name);
literal eval(std::unordered_map<std::string, argument> params) const;
argument eval(std::unordered_map<std::string, argument> params) const;
friend std::ostream& operator<<(std::ostream& os, const program& p);
......
......@@ -6,6 +6,11 @@
namespace rtg {
#define RTG_REQUIRES(...) class=typename std::enable_if<(__VA_ARGS__)>::type
struct raw_data_base
{};
/**
* @brief Provides a base class for common operations with raw buffer
*
......@@ -15,29 +20,8 @@ namespace rtg {
*
*/
template <class Derived>
struct raw_data
struct raw_data : raw_data_base
{
friend bool operator==(const Derived& x, const Derived& y)
{
auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape();
bool result = x.empty() && y.empty();
if(not result && xshape == yshape)
{
auto&& xbuffer = x.data();
auto&& ybuffer = y.data();
// TODO: Dont use tensor view for single values
xshape.visit_type([&](auto as) {
auto xview = make_view(xshape, as.from(xbuffer));
auto yview = make_view(yshape, as.from(ybuffer));
result = xview == yview;
});
}
return result;
}
friend bool operator!=(const Derived& x, const Derived& y) { return !(x == y); }
template <class Stream>
friend Stream& operator<<(Stream& os, const Derived& d)
{
......@@ -114,6 +98,32 @@ struct raw_data
auto_cast get() const { return {static_cast<const Derived*>(this)}; }
};
template<class T, class U, RTG_REQUIRES(std::is_base_of<raw_data_base, T>{}), RTG_REQUIRES(std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape();
bool result = x.empty() && y.empty();
if(not result && xshape == yshape)
{
auto&& xbuffer = x.data();
auto&& ybuffer = y.data();
// TODO: Dont use tensor view for single values
xshape.visit_type([&](auto as) {
auto xview = make_view(xshape, as.from(xbuffer));
auto yview = make_view(yshape, as.from(ybuffer));
result = xview == yview;
});
}
return result;
}
template<class T, class U, RTG_REQUIRES(std::is_base_of<raw_data_base, T>{}), RTG_REQUIRES(std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{
return !(x == y);
}
namespace detail {
template <class V, class... Ts>
void visit_all_impl(const shape& s, V&& v, Ts&&... xs)
......
......@@ -103,22 +103,6 @@ struct tensor_view
return m_data + this->size();
}
friend bool operator==(const tensor_view<T>& x, const tensor_view<T>& y)
{
if(x.m_shape == y.m_shape)
{
for(std::size_t i = 0; i < x.m_shape.elements(); i++)
{
if(!float_equal(x[i], y[i]))
return false;
}
return true;
}
return false;
}
friend bool operator!=(const tensor_view<T>& x, const tensor_view<T>& y) { return !(x == y); }
friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
{
if(!x.empty())
......@@ -137,6 +121,24 @@ struct tensor_view
shape m_shape;
};
template<class T, class U>
bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
{
if(x.get_shape() == y.get_shape())
{
for(std::size_t i = 0; i < x.get_shape().elements(); i++)
{
if(!float_equal(x[i], y[i]))
return false;
}
return true;
}
return false;
}
template<class T, class U>
bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y) { return !(x == y); }
template <class T>
tensor_view<T> make_view(shape s, T* data)
{
......
......@@ -112,7 +112,7 @@ void program::compile(const target& t)
RTG_THROW("Invalid program from compilation");
}
literal program::eval(std::unordered_map<std::string, argument> params) const
argument program::eval(std::unordered_map<std::string, argument> params) const
{
assert(this->validate() != impl->instructions.end());
std::unordered_map<const instruction*, argument> results;
......@@ -142,7 +142,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
}
results.emplace(std::addressof(ins), result);
}
return literal{result.get_shape(), result.data()};
return result;
}
std::ostream& operator<<(std::ostream& os, const program& p)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment