"vscode:/vscode.git/clone" did not exist on "bc6513a2710731841740382bdadd4eadd37a8acb"
Unverified Commit 3446bea5 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Enforce op name for check_shapes class (#633)



* Enforce op name for check_shapes class

* Add test for scalar

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 65bd2b18
...@@ -25,8 +25,6 @@ struct check_shapes ...@@ -25,8 +25,6 @@ struct check_shapes
{ {
} }
check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
template <class Op> template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) check_shapes(const std::vector<shape>& s, const Op& op)
: begin(s.data()), end(s.data() + s.size()), name(op.name()) : begin(s.data()), end(s.data() + s.size()), name(op.name())
...@@ -59,6 +57,13 @@ struct check_shapes ...@@ -59,6 +57,13 @@ struct check_shapes
return *this; return *this;
} }
const check_shapes& nelements(std::size_t n) const
{
if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements");
return *this;
}
const check_shapes& only_dims(std::size_t n) const const check_shapes& only_dims(std::size_t n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
......
...@@ -18,7 +18,7 @@ struct binary : op_name<Derived> ...@@ -18,7 +18,7 @@ struct binary : op_name<Derived>
value attributes() const { return base_attributes(); } value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs, static_cast<const Derived&>(*this)}.has(2).same_type().same_dims();
auto s0 = inputs.at(0); auto s0 = inputs.at(0);
auto s1 = inputs.at(1); auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed()) if(s0 == s1 and s0.packed())
......
...@@ -23,7 +23,7 @@ struct clip ...@@ -23,7 +23,7 @@ struct clip
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(3).same_type(); check_shapes{inputs, *this}.has(3).same_type();
return inputs.front(); return inputs.front();
} }
......
...@@ -28,7 +28,7 @@ struct flatten ...@@ -28,7 +28,7 @@ struct flatten
std::string name() const { return "flatten"; } std::string name() const { return "flatten"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
auto&& lens = inputs.front().lens(); auto&& lens = inputs.front().lens();
int64_t n_dim = static_cast<int64_t>(lens.size()); int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis > n_dim or axis < -n_dim) if(axis > n_dim or axis < -n_dim)
......
...@@ -27,7 +27,7 @@ struct load ...@@ -27,7 +27,7 @@ struct load
std::string name() const { return "load"; } std::string name() const { return "load"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
return s; return s;
} }
argument compute(const shape&, const std::vector<argument>& args) const argument compute(const shape&, const std::vector<argument>& args) const
......
...@@ -21,7 +21,7 @@ struct logsoftmax ...@@ -21,7 +21,7 @@ struct logsoftmax
std::string name() const { return "logsoftmax"; } std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
int64_t n_dim = static_cast<int64_t>(inputs[0].lens().size()); int64_t n_dim = static_cast<int64_t>(inputs[0].lens().size());
if(axis < -n_dim || axis >= n_dim) if(axis < -n_dim || axis >= n_dim)
{ {
......
...@@ -29,7 +29,7 @@ struct scalar ...@@ -29,7 +29,7 @@ struct scalar
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1); check_shapes{inputs, *this}.has(1).only_dims(1).nelements(1);
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
std::vector<std::size_t> strides(scalar_bcast_lens.size(), 0); std::vector<std::size_t> strides(scalar_bcast_lens.size(), 0);
return {t, scalar_bcast_lens, strides}; return {t, scalar_bcast_lens, strides};
......
...@@ -21,7 +21,7 @@ struct softmax ...@@ -21,7 +21,7 @@ struct softmax
std::string name() const { return "softmax"; } std::string name() const { return "softmax"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
int64_t n_dim = inputs[0].lens().size(); int64_t n_dim = inputs[0].lens().size();
if(axis < -n_dim || axis >= n_dim) if(axis < -n_dim || axis >= n_dim)
{ {
......
...@@ -18,7 +18,7 @@ struct unary : op_name<Derived> ...@@ -18,7 +18,7 @@ struct unary : op_name<Derived>
value attributes() const { return base_attributes(); } value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1);
auto s = inputs.at(0); auto s = inputs.at(0);
if(s.packed()) if(s.packed())
{ {
......
...@@ -605,7 +605,7 @@ struct cpu_gemm ...@@ -605,7 +605,7 @@ struct cpu_gemm
if(inputs.size() == 3) if(inputs.size() == 3)
{ {
auto c_shape = inputs.at(2); auto c_shape = inputs.at(2);
check_shapes{{c_shape}}.not_broadcasted(); check_shapes{{c_shape}, *this}.not_broadcasted();
} }
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
...@@ -658,7 +658,7 @@ struct cpu_quant_gemm ...@@ -658,7 +658,7 @@ struct cpu_quant_gemm
if(inputs.size() == 3) if(inputs.size() == 3)
{ {
auto c_shape = inputs.at(2); auto c_shape = inputs.at(2);
check_shapes{{c_shape}}.not_broadcasted(); check_shapes{{c_shape}, *this}.not_broadcasted();
} }
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
...@@ -751,7 +751,7 @@ struct cpu_unary : auto_register_op<cpu_unary<Op>> ...@@ -751,7 +751,7 @@ struct cpu_unary : auto_register_op<cpu_unary<Op>>
std::string name() const { return op.name(); } std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0); auto s = inputs.at(0);
return {s.type(), s.lens()}; return {s.type(), s.lens()};
} }
......
...@@ -9,7 +9,7 @@ namespace gpu { ...@@ -9,7 +9,7 @@ namespace gpu {
shape hip_convert::compute_shape(std::vector<shape> inputs) const shape hip_convert::compute_shape(std::vector<shape> inputs) const
{ {
inputs.pop_back(); inputs.pop_back();
check_shapes{inputs}.packed(); check_shapes{inputs, *this}.packed();
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
......
...@@ -39,7 +39,7 @@ struct rocblas_gemm ...@@ -39,7 +39,7 @@ struct rocblas_gemm
{ {
std::vector<shape> in_shapes(inputs); std::vector<shape> in_shapes(inputs);
in_shapes.pop_back(); in_shapes.pop_back();
check_shapes{in_shapes}.not_broadcasted(); check_shapes{in_shapes, *this}.not_broadcasted();
batch_not_transposed(inputs[0].strides()); batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides()); batch_not_transposed(inputs[1].strides());
......
...@@ -46,7 +46,7 @@ struct hip_allocate ...@@ -46,7 +46,7 @@ struct hip_allocate
std::string name() const { return "hip::allocate"; } std::string name() const { return "hip::allocate"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs}.has(0); check_shapes{inputs, *this}.has(0);
return s; return s;
} }
argument compute(context&, const shape& output_shape, const std::vector<argument>&) const argument compute(context&, const shape& output_shape, const std::vector<argument>&) const
...@@ -80,7 +80,7 @@ struct hip_copy_to_gpu ...@@ -80,7 +80,7 @@ struct hip_copy_to_gpu
std::string name() const { return "hip::copy_to_gpu"; } std::string name() const { return "hip::copy_to_gpu"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1, 2); check_shapes{inputs, *this}.has(1, 2);
return inputs.at(0); return inputs.at(0);
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
...@@ -106,7 +106,7 @@ struct hip_copy_from_gpu ...@@ -106,7 +106,7 @@ struct hip_copy_from_gpu
std::string name() const { return "hip::copy_from_gpu"; } std::string name() const { return "hip::copy_from_gpu"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1, 2); check_shapes{inputs, *this}.has(1, 2);
return inputs.at(0); return inputs.at(0);
} }
argument argument
...@@ -135,7 +135,7 @@ struct hip_copy ...@@ -135,7 +135,7 @@ struct hip_copy
std::string name() const { return "hip::copy"; } std::string name() const { return "hip::copy"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).standard(); check_shapes{inputs, *this}.has(2).standard();
return inputs.at(1); return inputs.at(1);
} }
argument compute(context& ctx, const shape&, std::vector<argument> args) const argument compute(context& ctx, const shape&, std::vector<argument> args) const
...@@ -162,7 +162,7 @@ struct hip_allocate_memory ...@@ -162,7 +162,7 @@ struct hip_allocate_memory
std::string name() const { return "hip::hip_allocate_memory"; } std::string name() const { return "hip::hip_allocate_memory"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs}.has(0); check_shapes{inputs, *this}.has(0);
return s; return s;
} }
...@@ -192,7 +192,7 @@ struct hip_copy_literal ...@@ -192,7 +192,7 @@ struct hip_copy_literal
std::string name() const { return "hip::hip_copy_literal"; } std::string name() const { return "hip::hip_copy_literal"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs}.has(0); check_shapes{inputs, *this}.has(0);
return l.get_shape(); return l.get_shape();
} }
......
...@@ -33,7 +33,7 @@ struct reduce_op : oper<Derived> ...@@ -33,7 +33,7 @@ struct reduce_op : oper<Derived>
{ {
std::vector<shape> in_shapes{inputs}; std::vector<shape> in_shapes{inputs};
in_shapes.pop_back(); in_shapes.pop_back();
check_shapes{in_shapes}.standard(); check_shapes{in_shapes, *this}.standard();
return op.compute_shape(in_shapes); return op.compute_shape(in_shapes);
} }
......
...@@ -25,7 +25,7 @@ struct allocate ...@@ -25,7 +25,7 @@ struct allocate
std::string name() const { return "allocate"; } std::string name() const { return "allocate"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{ {
migraphx::check_shapes{inputs}.has(0); migraphx::check_shapes{inputs, *this}.has(0);
return s; return s;
} }
migraphx::argument compute(migraphx::context&, migraphx::argument compute(migraphx::context&,
......
...@@ -65,7 +65,7 @@ struct allocate ...@@ -65,7 +65,7 @@ struct allocate
std::string name() const { return "allocate"; } std::string name() const { return "allocate"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{ {
migraphx::check_shapes{inputs}.has(0); migraphx::check_shapes{inputs, *this}.has(0);
return s; return s;
} }
migraphx::argument compute(migraphx::context&, migraphx::argument compute(migraphx::context&,
...@@ -81,7 +81,7 @@ struct simple_op ...@@ -81,7 +81,7 @@ struct simple_op
std::string name() const { return "simple_op"; } std::string name() const { return "simple_op"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{ {
migraphx::check_shapes{inputs}.has(1); migraphx::check_shapes{inputs, *this}.has(1);
return inputs.at(0); return inputs.at(0);
} }
migraphx::argument compute(migraphx::context&, migraphx::argument compute(migraphx::context&,
......
...@@ -558,6 +558,19 @@ TEST_CASE(test_argmin) ...@@ -558,6 +558,19 @@ TEST_CASE(test_argmin)
} }
} }
TEST_CASE(test_scalar)
{
migraphx::shape s1{migraphx::shape::float_type, {1}, {1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4, 5}, {0, 0, 0, 0}};
expect_shape(s2, migraphx::op::scalar{{2, 3, 4, 5}}, s1);
}
TEST_CASE(test_scalar_nelemnts)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::op::scalar{{2, 3, 4, 5}}, input);
}
TEST_CASE(test_squeeze) TEST_CASE(test_squeeze)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment