Commit ab855a89 authored by Paul's avatar Paul
Browse files

Enable nhwc convolution

parent 7b447623
...@@ -56,6 +56,8 @@ static void create_pointwise_modules(module_pass_manager& mpm) ...@@ -56,6 +56,8 @@ static void create_pointwise_modules(module_pass_manager& mpm)
{ {
if(not ins->get_operator().attributes().get("pointwise", false)) if(not ins->get_operator().attributes().get("pointwise", false))
continue; continue;
if (ins->get_operator().name() == "layout")
continue;
assert(ins->get_operator().attributes().contains("point_op")); assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++)); auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass(); pm->set_bypass();
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP #define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraphx/permutation.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -232,6 +233,18 @@ struct check_shapes ...@@ -232,6 +233,18 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are packed with certain layouts
*/
const check_shapes& packed_layouts(const std::initializer_list<std::vector<int64_t>>& layouts) const
{
if(not this->all_of([&](const shape& s) {
return s.packed() and contains(layouts, find_permutation(s));
}))
MIGRAPHX_THROW(prefix() + "Shapes are not packed with correct layout");
return *this;
}
/*! /*!
* Check all shapes are packed or broadcasted. * Check all shapes are packed or broadcasted.
*/ */
......
...@@ -31,9 +31,9 @@ namespace gpu { ...@@ -31,9 +31,9 @@ namespace gpu {
shape miopen_convolution::compute_shape(const std::vector<shape>& inputs) const shape miopen_convolution::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(4).standard(); check_shapes{inputs, *this}.has(4);
std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2); std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
check_shapes{conv_inputs, *this}.max_ndims(5); check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts({{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}});
return op.normalize_compute_shape(conv_inputs); return op.normalize_compute_shape(conv_inputs);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment