Unverified Commit b75c83d8 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Use jit for contiguous operator (#1217)

* Jit contiguous
parent 8c35fa94
...@@ -210,10 +210,10 @@ void replace(Range&& r, const T& old, const T& new_x) ...@@ -210,10 +210,10 @@ void replace(Range&& r, const T& old, const T& new_x)
std::replace(r.begin(), r.end(), old, new_x); std::replace(r.begin(), r.end(), old, new_x);
} }
template <class R1, class R2> template <class R1, class R2, class... Predicate>
bool equal(R1&& r1, R2&& r2) bool equal(R1&& r1, R2&& r2, Predicate... pred)
{ {
return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end()); return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...);
} }
template <class R> template <class R>
......
...@@ -61,9 +61,7 @@ struct shape_impl ...@@ -61,9 +61,7 @@ struct shape_impl
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
// assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and m_standard = this->elements() == this->element_space() and not skips() and
// "At least one stride must be non-zero");
m_standard = this->elements() == this->element_space() and
std::is_sorted(m_strides.rbegin(), m_strides.rend()); std::is_sorted(m_strides.rbegin(), m_strides.rend());
} }
...@@ -110,6 +108,15 @@ struct shape_impl ...@@ -110,6 +108,15 @@ struct shape_impl
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
// Does the shape skip over elements?
bool skips() const
{
assert(m_lens.size() == m_strides.size());
if(elements() == 1)
return false;
return std::none_of(m_strides.begin(), m_strides.end(), [](auto x) { return x == 1; });
}
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); } std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
}; };
...@@ -260,7 +267,8 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end ...@@ -260,7 +267,8 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
bool shape::packed() const bool shape::packed() const
{ {
return this->sub_shapes().empty() and this->elements() == this->element_space(); return this->sub_shapes().empty() and not impl->skips() and
this->elements() == this->element_space();
} }
bool shape::transposed() const bool shape::transposed() const
...@@ -285,10 +293,8 @@ bool shape::transposed() const ...@@ -285,10 +293,8 @@ bool shape::transposed() const
bool shape::broadcasted() const bool shape::broadcasted() const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::accumulate(this->strides().begin(), return std::any_of(
this->strides().end(), this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; });
std::size_t{1},
std::multiplies<std::size_t>()) == 0;
} }
bool shape::scalar() const bool shape::scalar() const
......
...@@ -48,6 +48,7 @@ ...@@ -48,6 +48,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/array.hpp> #include <migraphx/array.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp> #include <migraphx/op/clip.hpp>
#include <cmath> #include <cmath>
#include <set> #include <set>
...@@ -1012,9 +1013,43 @@ struct find_commutative_broadcast ...@@ -1012,9 +1013,43 @@ struct find_commutative_broadcast
} }
}; };
struct find_contiguous
{
auto matcher() const { return match::name("gpu::contiguous"); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
m.replace_instruction(
ins,
make_op("gpu::precompile_op", {{"op", to_value(make_op("contiguous"))}}),
ins->inputs());
}
};
struct find_contiguous_pointwise
{
auto matcher() const
{
return match::name("gpu::contiguous")(match::arg(0)(precompile_name("pointwise")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto pw = ins->inputs().front();
auto alloc = ins->inputs().back();
auto args = pw->inputs();
args.back() = alloc;
m.replace_instruction(ins, pw->get_operator(), args, pw->module_inputs());
}
};
void fuse_ops::apply(module& m) const void fuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_gelu{}, find_gelu_new{fast_math}); match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math});
run_passes(m, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(m, find_triadd{}); match::find_matches(m, find_triadd{});
match::find_matches(m, match::find_matches(m,
...@@ -1036,6 +1071,7 @@ void fuse_ops::apply(module& m) const ...@@ -1036,6 +1071,7 @@ void fuse_ops::apply(module& m) const
find_gemm_add{}, find_gemm_add{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_commutative_broadcast{}); find_commutative_broadcast{});
match::find_matches(m, find_contiguous{});
} }
} // namespace gpu } // namespace gpu
......
...@@ -79,7 +79,7 @@ static std::vector<std::string> get_op_names(const module& m) ...@@ -79,7 +79,7 @@ static std::vector<std::string> get_op_names(const module& m)
struct pointwise_compiler : compiler<pointwise_compiler> struct pointwise_compiler : compiler<pointwise_compiler>
{ {
std::vector<std::string> names() const { return {"pointwise"}; } std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
static std::size_t oversubscribe_if(bool b) static std::size_t oversubscribe_if(bool b)
{ {
...@@ -114,7 +114,16 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -114,7 +114,16 @@ struct pointwise_compiler : compiler<pointwise_compiler>
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
if(op.name() == "contiguous")
{
return replace(compile_op(
ctx,
to_shapes(ins->inputs()),
{{"lambda", "[](auto x) { return x; }"}, {"kernel", "contiguous_kernel"}}));
}
else
{ {
assert(not ins->module_inputs().empty()); assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
...@@ -130,19 +139,21 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -130,19 +139,21 @@ struct pointwise_compiler : compiler<pointwise_compiler>
g.add_point_op("greater", "migraphx::abs(${0} > ${1})"); g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})"); g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explict conversions // Add explict conversions
g.fresult( g.fresult([](const shape& s) {
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; }); return "migraphx::convert<" + shape::cpp_type(s.type()) + ">";
});
auto name = g.create_function( auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm)); g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")"; std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
auto op_names = get_op_names(*pm); auto op_names = get_op_names(*pm);
op_names.push_back("kernel"); op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_"); auto op_name_string = join_strings(op_names, "_");
return replace( return replace(compile_op(
compile_op(ctx, ctx,
to_shapes(ins->inputs()), to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}})); {{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
} }
}
}; };
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -49,7 +49,7 @@ constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op) ...@@ -49,7 +49,7 @@ constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{ {
for(; first != last; ++first) for(; first != last; ++first)
{ {
init = op(std::move(init), *first); init = op(static_cast<T&&>(init), *first);
} }
return init; return init;
} }
...@@ -64,6 +64,20 @@ constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first) ...@@ -64,6 +64,20 @@ constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
return d_first; return d_first;
} }
template <class InputIt, class OutputIt, class UnaryPredicate>
constexpr OutputIt copy_if(InputIt first, InputIt last, OutputIt d_first, UnaryPredicate pred)
{
for(; first != last; ++first)
{
if(pred(*first))
{
*d_first = *first;
++d_first;
}
}
return d_first;
}
template <class Iterator, class Compare> template <class Iterator, class Compare>
constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp) constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp)
{ {
...@@ -115,6 +129,24 @@ constexpr Iterator find(Iterator first, Iterator last, const T& value) ...@@ -115,6 +129,24 @@ constexpr Iterator find(Iterator first, Iterator last, const T& value)
return find_if(first, last, [&](const auto& x) { return x == value; }); return find_if(first, last, [&](const auto& x) { return x == value; });
} }
template <class InputIt, class UnaryPredicate>
constexpr bool any_of(InputIt first, InputIt last, UnaryPredicate p)
{
return find_if(first, last, p) != last;
}
template <class InputIt, class UnaryPredicate>
constexpr bool none_of(InputIt first, InputIt last, UnaryPredicate p)
{
return find_if(first, last, p) == last;
}
template <class InputIt, class UnaryPredicate>
constexpr bool all_of(InputIt first, InputIt last, UnaryPredicate p)
{
return none_of(first, last, [=](auto&& x) { return not p(x); });
}
template <class Iterator1, class Iterator2> template <class Iterator1, class Iterator2>
constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last) constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last)
{ {
......
...@@ -40,10 +40,17 @@ struct implicit_conversion_op ...@@ -40,10 +40,17 @@ struct implicit_conversion_op
template <index_int N, class U> template <index_int N, class U>
constexpr operator vec<U, N>() const constexpr operator vec<U, N>() const
{
if constexpr(vec_size<T>() == 0)
{
return x;
}
else
{ {
static_assert(vec_size<T>() == N, "Vector mismatch size"); static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>); return __builtin_convertvector(x, vec<U, N>);
} }
}
template <class U> template <class U>
constexpr operator U() const constexpr operator U() const
......
...@@ -44,7 +44,7 @@ struct shape ...@@ -44,7 +44,7 @@ struct shape
constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; } constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; }
constexpr auto packed() const { return elements() == element_space(); } constexpr auto packed() const { return not skips() and elements() == element_space(); }
constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; } constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; }
constexpr auto transposed() const constexpr auto transposed() const
{ {
...@@ -53,16 +53,9 @@ struct shape ...@@ -53,16 +53,9 @@ struct shape
if(shape{}.broadcasted()) if(shape{}.broadcasted())
{ {
index_array s{}; index_array s{};
index_int j = 0; auto out = copy_if(
for(index_int i = 0; i < s.size(); i++) lstrides.begin(), lstrides.end(), s.begin(), [](auto x) { return x != 0; });
{ return not is_sorted(s.begin(), out, greater{});
if(lstrides[i] != 0)
{
s[j] = lstrides[i];
j++;
}
}
return not is_sorted(s.begin(), s.begin() + j, greater{});
} }
else else
{ {
...@@ -70,6 +63,13 @@ struct shape ...@@ -70,6 +63,13 @@ struct shape
} }
}); });
} }
constexpr auto skips() const
{
return return_c([] {
auto lstrides = Strides{};
return none_of(lstrides.begin(), lstrides.end(), [](auto x) { return x == 1; });
});
}
constexpr auto standard() const { return packed() and not transposed(); } constexpr auto standard() const { return packed() and not transposed(); }
...@@ -86,8 +86,17 @@ struct shape ...@@ -86,8 +86,17 @@ struct shape
constexpr index_int index(index_int i) const constexpr index_int index(index_int i) const
{ {
if(this->standard()) if(this->standard())
{
MIGRAPHX_ASSERT(i == compute_index(i));
return i; return i;
}
else else
{
return compute_index(i);
}
}
constexpr index_int compute_index(index_int i) const
{ {
const auto rank = this->lens.size(); const auto rank = this->lens.size();
index_int s = 1; index_int s = 1;
...@@ -104,7 +113,6 @@ struct shape ...@@ -104,7 +113,6 @@ struct shape
} }
return result; return result;
} }
}
/// Convert single index into a multi-index /// Convert single index into a multi-index
constexpr index_array multi(index_int idx) const constexpr index_array multi(index_int idx) const
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
*/ */
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp" #include "test.hpp"
migraphx::shape make_shape(std::vector<std::size_t> lens) migraphx::shape make_shape(std::vector<std::size_t> lens)
...@@ -35,6 +36,21 @@ migraphx::shape make_shape(std::vector<std::size_t> lens, std::vector<std::size_ ...@@ -35,6 +36,21 @@ migraphx::shape make_shape(std::vector<std::size_t> lens, std::vector<std::size_
return {migraphx::shape::float_type, std::move(lens), std::move(strides)}; return {migraphx::shape::float_type, std::move(lens), std::move(strides)};
} }
bool verify_shape(const migraphx::shape& s1, const migraphx::shape& s2)
{
if(s1.elements() != s2.elements())
return false;
return migraphx::all_of(migraphx::range(s1.elements()),
[&](auto i) { return s1.index(i) == s2.index(i); });
}
template <class Range1, class Range2>
bool verify_shapes(const Range1& r1, const Range2& r2)
{
return migraphx::equal(
r1, r2, [](const auto& s1, const auto& s2) { return verify_shape(s1, s2); });
}
TEST_CASE(same_standard) TEST_CASE(same_standard)
{ {
auto is = make_shape({64, 3, 7, 7}); auto is = make_shape({64, 3, 7, 7});
...@@ -42,7 +58,7 @@ TEST_CASE(same_standard) ...@@ -42,7 +58,7 @@ TEST_CASE(same_standard)
std::vector<migraphx::shape> ishapes = {is, is, is}; std::vector<migraphx::shape> ishapes = {is, is, is};
std::vector<migraphx::shape> eshapes = {os, os, os}; std::vector<migraphx::shape> eshapes = {os, os, os};
auto rshapes = migraphx::reduce_dims(ishapes); auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes); EXPECT(eshapes == rshapes);
} }
...@@ -53,7 +69,7 @@ TEST_CASE(same_broadcast1) ...@@ -53,7 +69,7 @@ TEST_CASE(same_broadcast1)
std::vector<migraphx::shape> ishapes = {is, make_shape({64, 3, 7, 7}, {0, 1, 0, 0}), is}; std::vector<migraphx::shape> ishapes = {is, make_shape({64, 3, 7, 7}, {0, 1, 0, 0}), is};
std::vector<migraphx::shape> eshapes = {os, make_shape({64, 3, 7 * 7}, {0, 1, 0}), os}; std::vector<migraphx::shape> eshapes = {os, make_shape({64, 3, 7 * 7}, {0, 1, 0}), os};
auto rshapes = migraphx::reduce_dims(ishapes); auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes); EXPECT(eshapes == rshapes);
} }
...@@ -64,7 +80,7 @@ TEST_CASE(same_broadcast2) ...@@ -64,7 +80,7 @@ TEST_CASE(same_broadcast2)
std::vector<migraphx::shape> ishapes = {is, make_shape({64, 3, 8, 7, 7}, {0, 8, 1, 0, 0}), is}; std::vector<migraphx::shape> ishapes = {is, make_shape({64, 3, 8, 7, 7}, {0, 8, 1, 0, 0}), is};
std::vector<migraphx::shape> eshapes = {os, make_shape({64, 8 * 3, 7 * 7}, {0, 1, 0}), os}; std::vector<migraphx::shape> eshapes = {os, make_shape({64, 8 * 3, 7 * 7}, {0, 1, 0}), os};
auto rshapes = migraphx::reduce_dims(ishapes); auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes); EXPECT(eshapes == rshapes);
} }
...@@ -75,7 +91,7 @@ TEST_CASE(same_transposed) ...@@ -75,7 +91,7 @@ TEST_CASE(same_transposed)
std::vector<migraphx::shape> ishapes = {is, migraphx::reorder_shape(is, {0, 1, 3, 2}), is}; std::vector<migraphx::shape> ishapes = {is, migraphx::reorder_shape(is, {0, 1, 3, 2}), is};
std::vector<migraphx::shape> eshapes = {os, migraphx::reorder_shape(os, {0, 2, 1}), os}; std::vector<migraphx::shape> eshapes = {os, migraphx::reorder_shape(os, {0, 2, 1}), os};
auto rshapes = migraphx::reduce_dims(ishapes); auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes); EXPECT(eshapes == rshapes);
} }
...@@ -86,7 +102,7 @@ TEST_CASE(different_masked1) ...@@ -86,7 +102,7 @@ TEST_CASE(different_masked1)
std::vector<migraphx::shape> ishapes = {is, make_shape({1, 3, 1, 1}), is}; std::vector<migraphx::shape> ishapes = {is, make_shape({1, 3, 1, 1}), is};
std::vector<migraphx::shape> eshapes = {os, make_shape({1, 3, 1}), os}; std::vector<migraphx::shape> eshapes = {os, make_shape({1, 3, 1}), os};
auto rshapes = migraphx::reduce_dims(ishapes); auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes); EXPECT(eshapes == rshapes);
} }
...@@ -98,7 +114,7 @@ TEST_CASE(different_masked2) ...@@ -98,7 +114,7 @@ TEST_CASE(different_masked2)
is, make_shape({1, 3, 1, 1}), make_shape({64, 1, 7, 7})}; is, make_shape({1, 3, 1, 1}), make_shape({64, 1, 7, 7})};
std::vector<migraphx::shape> eshapes = {os, make_shape({1, 3, 1}), make_shape({64, 1, 7 * 7})}; std::vector<migraphx::shape> eshapes = {os, make_shape({1, 3, 1}), make_shape({64, 1, 7 * 7})};
auto rshapes = migraphx::reduce_dims(ishapes); auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes); EXPECT(eshapes == rshapes);
} }
...@@ -128,7 +144,7 @@ TEST_CASE(transposed1) ...@@ -128,7 +144,7 @@ TEST_CASE(transposed1)
std::vector<migraphx::shape> eshapes = { std::vector<migraphx::shape> eshapes = {
make_shape({8, 28, 4, 56 * 56}), make_shape({8, 28, 4, 56 * 56}, {351232, 3136, 87808, 1})}; make_shape({8, 28, 4, 56 * 56}), make_shape({8, 28, 4, 56 * 56}, {351232, 3136, 87808, 1})};
auto rshapes = migraphx::reduce_dims(ishapes); auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes); EXPECT(eshapes == rshapes);
} }
...@@ -137,6 +153,7 @@ TEST_CASE(non_packed_empty1) ...@@ -137,6 +153,7 @@ TEST_CASE(non_packed_empty1)
std::vector<migraphx::shape> ishapes = {make_shape({1, 12}, {589824, 64})}; std::vector<migraphx::shape> ishapes = {make_shape({1, 12}, {589824, 64})};
std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})}; std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})};
auto rshapes = migraphx::reduce_dims(ishapes); auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes); EXPECT(eshapes == rshapes);
} }
...@@ -145,6 +162,7 @@ TEST_CASE(non_packed_empty2) ...@@ -145,6 +162,7 @@ TEST_CASE(non_packed_empty2)
std::vector<migraphx::shape> ishapes = {make_shape({12, 1}, {64, 589824})}; std::vector<migraphx::shape> ishapes = {make_shape({12, 1}, {64, 589824})};
std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})}; std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})};
auto rshapes = migraphx::reduce_dims(ishapes); auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes); EXPECT(eshapes == rshapes);
} }
...@@ -155,6 +173,16 @@ TEST_CASE(single_dim) ...@@ -155,6 +173,16 @@ TEST_CASE(single_dim)
EXPECT(ishapes == rshapes); EXPECT(ishapes == rshapes);
} }
TEST_CASE(step_broadcast_transpose)
{
std::vector<migraphx::shape> ishapes = {make_shape({1, 2, 2, 1}, {0, 0, 3, 6}),
make_shape({1, 2, 2, 1}, {4, 2, 1, 1})};
std::vector<migraphx::shape> eshapes = {make_shape({2, 2}, {0, 3}), make_shape({2, 2}, {2, 1})};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes);
}
TEST_CASE(empty) TEST_CASE(empty)
{ {
auto rshapes = migraphx::reduce_dims({}); auto rshapes = migraphx::reduce_dims({});
......
...@@ -200,6 +200,15 @@ TEST_CASE(test_shape_broadcasted5) ...@@ -200,6 +200,15 @@ TEST_CASE(test_shape_broadcasted5)
EXPECT(s.broadcasted()); EXPECT(s.broadcasted());
} }
TEST_CASE(test_shape_step_broadcasted)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {0, 3}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_default_copy) TEST_CASE(test_shape_default_copy)
{ {
migraphx::shape s1{}; migraphx::shape s1{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment