Commit cf296f0c authored by Paul's avatar Paul
Browse files

Format

parent ab855a89
...@@ -56,7 +56,7 @@ static void create_pointwise_modules(module_pass_manager& mpm) ...@@ -56,7 +56,7 @@ static void create_pointwise_modules(module_pass_manager& mpm)
{ {
if(not ins->get_operator().attributes().get("pointwise", false)) if(not ins->get_operator().attributes().get("pointwise", false))
continue; continue;
if (ins->get_operator().name() == "layout") if(ins->get_operator().name() == "layout")
continue; continue;
assert(ins->get_operator().attributes().contains("point_op")); assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++)); auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
......
...@@ -236,7 +236,8 @@ struct check_shapes ...@@ -236,7 +236,8 @@ struct check_shapes
/*! /*!
* Check all shapes are packed with certain layouts * Check all shapes are packed with certain layouts
*/ */
const check_shapes& packed_layouts(const std::initializer_list<std::vector<int64_t>>& layouts) const const check_shapes&
packed_layouts(const std::initializer_list<std::vector<int64_t>>& layouts) const
{ {
if(not this->all_of([&](const shape& s) { if(not this->all_of([&](const shape& s) {
return s.packed() and contains(layouts, find_permutation(s)); return s.packed() and contains(layouts, find_permutation(s));
......
...@@ -33,7 +33,8 @@ shape miopen_convolution::compute_shape(const std::vector<shape>& inputs) const ...@@ -33,7 +33,8 @@ shape miopen_convolution::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(4); check_shapes{inputs, *this}.has(4);
std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2); std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts({{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}}); check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts(
{{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}});
return op.normalize_compute_shape(conv_inputs); return op.normalize_compute_shape(conv_inputs);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment