Commit ab855a89 authored by Paul's avatar Paul
Browse files

Enable nhwc convolution

parent 7b447623
......@@ -56,6 +56,8 @@ static void create_pointwise_modules(module_pass_manager& mpm)
{
if(not ins->get_operator().attributes().get("pointwise", false))
continue;
if (ins->get_operator().name() == "layout")
continue;
assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass();
......
......@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraphx/permutation.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
......@@ -232,6 +233,18 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are packed with certain layouts
*/
const check_shapes& packed_layouts(const std::initializer_list<std::vector<int64_t>>& layouts) const
{
if(not this->all_of([&](const shape& s) {
return s.packed() and contains(layouts, find_permutation(s));
}))
MIGRAPHX_THROW(prefix() + "Shapes are not packed with correct layout");
return *this;
}
/*!
* Check all shapes are packed or broadcasted.
*/
......
......@@ -31,9 +31,9 @@ namespace gpu {
shape miopen_convolution::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).standard();
check_shapes{inputs, *this}.has(4);
std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
check_shapes{conv_inputs, *this}.max_ndims(5);
check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts({{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}});
return op.normalize_compute_shape(conv_inputs);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment