Unverified commit 3446bea5, authored by Paul Fultz II, committed by GitHub
Browse files

Enforce op name for check_shapes class (#633)



* Enforce op name for check_shapes class

* Add test for scalar

* Formatting
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 65bd2b18
......@@ -25,8 +25,6 @@ struct check_shapes
{
}
check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op)
: begin(s.data()), end(s.data() + s.size()), name(op.name())
......@@ -59,6 +57,13 @@ struct check_shapes
return *this;
}
// Checks that every shape in the inspected range contains exactly `n`
// elements. Throws (message prefixed with the op name) on mismatch;
// returns *this so further checks can be chained fluently.
const check_shapes& nelements(std::size_t n) const
{
    const auto has_n_elements = [&](const shape& s) { return s.elements() == n; };
    if(not this->all_of(has_n_elements))
        MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements");
    return *this;
}
const check_shapes& only_dims(std::size_t n) const
{
assert(begin != nullptr);
......
......@@ -18,7 +18,7 @@ struct binary : op_name<Derived>
/// Exposes the op's attribute set; delegates to base_attributes(),
/// which is defined in a base class outside this view.
value attributes() const
{
    return base_attributes();
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(2).same_type().same_dims();
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
......
......@@ -23,7 +23,7 @@ struct clip
// Computes clip's output shape: requires exactly 3 inputs of the same
// type and returns the first input's shape unchanged.
shape compute_shape(std::vector<shape> inputs) const
{
    // Diff residue removed: the stale pre-change check_shapes call (without
    // the op) is dropped; *this supplies the op name for error messages.
    check_shapes{inputs, *this}.has(3).same_type();
    return inputs.front();
}
......
......@@ -28,7 +28,7 @@ struct flatten
/// Operator identifier string for this op.
std::string name() const
{
    return "flatten";
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
auto&& lens = inputs.front().lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis > n_dim or axis < -n_dim)
......
......@@ -27,7 +27,7 @@ struct load
/// Operator identifier string for this op.
std::string name() const
{
    return "load";
}
// load takes exactly one input and reports the stored member shape `s`
// (declared outside this view).
shape compute_shape(const std::vector<shape>& inputs) const
{
    // Diff residue removed: keep only the op-named check so failures
    // identify "load" in the error message.
    check_shapes{inputs, *this}.has(1);
    return s;
}
argument compute(const shape&, const std::vector<argument>& args) const
......
......@@ -21,7 +21,7 @@ struct logsoftmax
/// Operator identifier string for this op.
std::string name() const
{
    return "logsoftmax";
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1).standard();
check_shapes{inputs, *this}.has(1).standard();
int64_t n_dim = static_cast<int64_t>(inputs[0].lens().size());
if(axis < -n_dim || axis >= n_dim)
{
......
......@@ -29,7 +29,7 @@ struct scalar
shape compute_shape(std::vector<shape> inputs) const
{
assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1);
check_shapes{inputs, *this}.has(1).only_dims(1).nelements(1);
auto t = inputs.at(0).type();
std::vector<std::size_t> strides(scalar_bcast_lens.size(), 0);
return {t, scalar_bcast_lens, strides};
......
......@@ -21,7 +21,7 @@ struct softmax
/// Operator identifier string for this op.
std::string name() const
{
    return "softmax";
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1).standard();
check_shapes{inputs, *this}.has(1).standard();
int64_t n_dim = inputs[0].lens().size();
if(axis < -n_dim || axis >= n_dim)
{
......
......@@ -18,7 +18,7 @@ struct unary : op_name<Derived>
/// Exposes the op's attribute set; delegates to base_attributes(),
/// which is defined in a base class outside this view.
value attributes() const
{
    return base_attributes();
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1);
auto s = inputs.at(0);
if(s.packed())
{
......
......@@ -605,7 +605,7 @@ struct cpu_gemm
if(inputs.size() == 3)
{
auto c_shape = inputs.at(2);
check_shapes{{c_shape}}.not_broadcasted();
check_shapes{{c_shape}, *this}.not_broadcasted();
}
return op.compute_shape(inputs);
}
......@@ -658,7 +658,7 @@ struct cpu_quant_gemm
if(inputs.size() == 3)
{
auto c_shape = inputs.at(2);
check_shapes{{c_shape}}.not_broadcasted();
check_shapes{{c_shape}, *this}.not_broadcasted();
}
return op.compute_shape(inputs);
}
......@@ -751,7 +751,7 @@ struct cpu_unary : auto_register_op<cpu_unary<Op>>
/// Forwards the wrapped op's identifier as this op's name.
std::string name() const
{
    return op.name();
}
// Computes the output shape for the CPU unary wrapper: one input,
// result has the same type and lens as that input.
shape compute_shape(const std::vector<shape>& inputs) const
{
    // Diff residue removed: only the op-named check remains, so shape
    // errors report which operator failed.
    check_shapes{inputs, *this}.has(1);
    auto s = inputs.at(0);
    return {s.type(), s.lens()};
}
......
......@@ -9,7 +9,7 @@ namespace gpu {
// Computes hip_convert's output shape. The trailing input is dropped
// before validation (NOTE(review): presumably the preallocated output
// buffer, as is common for gpu ops — confirm against callers), then the
// remaining packed inputs are delegated to the wrapped op.
shape hip_convert::compute_shape(std::vector<shape> inputs) const
{
    inputs.pop_back();
    // Diff residue removed: pass *this so check failures name this op.
    check_shapes{inputs, *this}.packed();
    return op.compute_shape(inputs);
}
......
......@@ -39,7 +39,7 @@ struct rocblas_gemm
{
std::vector<shape> in_shapes(inputs);
in_shapes.pop_back();
check_shapes{in_shapes}.not_broadcasted();
check_shapes{in_shapes, *this}.not_broadcasted();
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
......
......@@ -46,7 +46,7 @@ struct hip_allocate
/// Operator identifier string for this op.
std::string name() const
{
    return "hip::allocate";
}
// hip_allocate takes no inputs and reports the stored member shape `s`.
shape compute_shape(const std::vector<shape>& inputs) const
{
    // Diff residue removed: keep only the op-named check.
    check_shapes{inputs, *this}.has(0);
    return s;
}
argument compute(context&, const shape& output_shape, const std::vector<argument>&) const
......@@ -80,7 +80,7 @@ struct hip_copy_to_gpu
/// Operator identifier string for this op.
std::string name() const
{
    return "hip::copy_to_gpu";
}
// Accepts 1 or 2 inputs; the output shape mirrors the first input.
shape compute_shape(std::vector<shape> inputs) const
{
    // Diff residue removed: *this supplies the op name for diagnostics.
    check_shapes{inputs, *this}.has(1, 2);
    return inputs.at(0);
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......@@ -106,7 +106,7 @@ struct hip_copy_from_gpu
/// Operator identifier string for this op.
std::string name() const
{
    return "hip::copy_from_gpu";
}
// Accepts 1 or 2 inputs; the output shape mirrors the first input.
shape compute_shape(std::vector<shape> inputs) const
{
    // Diff residue removed: *this supplies the op name for diagnostics.
    check_shapes{inputs, *this}.has(1, 2);
    return inputs.at(0);
}
argument
......@@ -135,7 +135,7 @@ struct hip_copy
/// Operator identifier string for this op.
std::string name() const
{
    return "hip::copy";
}
// Requires two standard-layout inputs; the result is the second
// input's shape.
shape compute_shape(std::vector<shape> inputs) const
{
    // Diff residue removed: *this supplies the op name for diagnostics.
    check_shapes{inputs, *this}.has(2).standard();
    return inputs.at(1);
}
argument compute(context& ctx, const shape&, std::vector<argument> args) const
......@@ -162,7 +162,7 @@ struct hip_allocate_memory
/// Operator identifier string for this op.
std::string name() const
{
    return "hip::hip_allocate_memory";
}
// Takes no inputs and reports the stored member shape `s`.
shape compute_shape(const std::vector<shape>& inputs) const
{
    // Diff residue removed: keep only the op-named check.
    check_shapes{inputs, *this}.has(0);
    return s;
}
......@@ -192,7 +192,7 @@ struct hip_copy_literal
/// Operator identifier string for this op.
std::string name() const
{
    return "hip::hip_copy_literal";
}
// Takes no inputs; the output shape is that of the stored literal `l`.
shape compute_shape(const std::vector<shape>& inputs) const
{
    // Diff residue removed: keep only the op-named check.
    check_shapes{inputs, *this}.has(0);
    return l.get_shape();
}
......
......@@ -33,7 +33,7 @@ struct reduce_op : oper<Derived>
{
std::vector<shape> in_shapes{inputs};
in_shapes.pop_back();
check_shapes{in_shapes}.standard();
check_shapes{in_shapes, *this}.standard();
return op.compute_shape(in_shapes);
}
......
......@@ -25,7 +25,7 @@ struct allocate
/// Operator identifier string for this test op.
std::string name() const
{
    return "allocate";
}
// allocate takes no inputs and reports the preset member shape `s`.
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{
    // Diff residue removed: *this supplies the op name for diagnostics.
    migraphx::check_shapes{inputs, *this}.has(0);
    return s;
}
migraphx::argument compute(migraphx::context&,
......
......@@ -65,7 +65,7 @@ struct allocate
/// Operator identifier string for this test op.
std::string name() const
{
    return "allocate";
}
// allocate takes no inputs and reports the preset member shape `s`.
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{
    // Diff residue removed: *this supplies the op name for diagnostics.
    migraphx::check_shapes{inputs, *this}.has(0);
    return s;
}
migraphx::argument compute(migraphx::context&,
......@@ -81,7 +81,7 @@ struct simple_op
/// Operator identifier string for this test op.
std::string name() const
{
    return "simple_op";
}
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{
migraphx::check_shapes{inputs}.has(1);
migraphx::check_shapes{inputs, *this}.has(1);
return inputs.at(0);
}
migraphx::argument compute(migraphx::context&,
......
......@@ -558,6 +558,19 @@ TEST_CASE(test_argmin)
}
}
// scalar must broadcast a single-element input to the requested lens,
// with all-zero strides so every index maps to the one stored element.
TEST_CASE(test_scalar)
{
    migraphx::shape input{migraphx::shape::float_type, {1}, {1}};
    migraphx::shape expected{migraphx::shape::float_type, {2, 3, 4, 5}, {0, 0, 0, 0}};
    expect_shape(expected, migraphx::op::scalar{{2, 3, 4, 5}}, input);
}
// scalar requires its input to hold exactly one element, so a
// multi-element input must make compute_shape throw.
// Renamed from "test_scalar_nelemnts" to fix the spelling of "nelements".
TEST_CASE(test_scalar_nelements)
{
    migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
    throws_shape(migraphx::op::scalar{{2, 3, 4, 5}}, input);
}
TEST_CASE(test_squeeze)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment