Unverified Commit 85b0563c authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add with_type to shape class (#1102)

Add with_type to shape class
parent 40c087bd
......@@ -131,6 +131,8 @@ struct shape
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
shape with_lens(const std::vector<std::size_t>& l) const;
shape with_type(type_t t) const;
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x);
......@@ -225,6 +227,7 @@ struct shape
const std::vector<shape>& sub_shapes() const;
private:
shape(std::shared_ptr<shape_impl> pimpl);
std::shared_ptr<const shape_impl> impl;
std::size_t element_space() const;
......
......@@ -86,6 +86,8 @@ struct shape_impl
return std::accumulate(
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
};
const std::vector<shape::type_t>& shape::types()
......@@ -135,6 +137,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
shape shape::from_permutation(type_t t,
const std::vector<std::size_t>& l,
const std::vector<int64_t>& perm)
......@@ -294,6 +298,13 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const
return this->with_lens(this->type(), l);
}
shape shape::with_type(type_t t) const
{
auto c = impl->copy();
c->m_type = t;
return {c};
}
std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); }
......
......@@ -608,4 +608,15 @@ TEST_CASE(cpp_type_name)
EXPECT(test::throws([&] { migraphx::shape::cpp_type(migraphx::shape::tuple_type); }));
}
TEST_CASE(test_with_type)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}};
EXPECT(s.type() == migraphx::shape::float_type);
auto new_s = s.with_type(migraphx::shape::half_type);
EXPECT(s.type() == migraphx::shape::float_type);
EXPECT(s.type() != new_s.type());
EXPECT(s.lens() == new_s.lens());
EXPECT(s.strides() == new_s.strides());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment