Commit 1a430735 authored by Paul's avatar Paul
Browse files

Add support for wingrad fusions

parent d8bf45cf
...@@ -2,8 +2,5 @@ pfultz2/rocm-recipes ...@@ -2,8 +2,5 @@ pfultz2/rocm-recipes
pcre pcre
danmar/cppcheck@f965e5873 -DHAVE_RULES=1 danmar/cppcheck@f965e5873 -DHAVE_RULES=1
ROCm-Developer-Tools/HIP@3a41f286203968421c557338d6fb39c36f3c717c ROCm-Developer-Tools/HIP@3a41f286203968421c557338d6fb39c36f3c717c
# Needed for clang-ocl
RadeonOpenCompute/rocm-cmake@6240bb3 --build
RadeonOpenCompute/clang-ocl@799713643b5591a3b877c586ef2c7fbc012af819
# python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36 # python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
-f requirements.txt -f requirements.txt
...@@ -45,6 +45,12 @@ bool contains(const C& c, const T& x) ...@@ -45,6 +45,12 @@ bool contains(const C& c, const T& x)
return generic_find(c, x) != c.end(); return generic_find(c, x) != c.end();
} }
template <class T>
bool contains(const std::initializer_list<T>& c, const T& x)
{
return generic_find(c, x) != c.end();
}
template <class T, class U> template <class T, class U>
bool contains(const std::initializer_list<T>& c, const U& x) bool contains(const std::initializer_list<T>& c, const U& x)
{ {
......
...@@ -134,15 +134,13 @@ MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins) ...@@ -134,15 +134,13 @@ MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins)
return false; return false;
auto wei = ins->inputs().at(1)->get_shape(); auto wei = ins->inputs().at(1)->get_shape();
assert(wei.lens().size() == 4); assert(wei.lens().size() == 4);
auto channels = wei.lens()[1] * wei.lens()[0];
if(wei.lens()[0] > 64 and channels > 32768)
return false;
auto conv = any_cast<miopen_convolution>(ins->get_operator()); auto conv = any_cast<miopen_convolution>(ins->get_operator());
if(conv.algo == miopenConvolutionFwdAlgoWinograd) if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
return false; return false;
auto op = conv.op; auto op = conv.op;
return op.padding == make_array<size_t>(0, 0) and op.stride == make_array<size_t>(1, 1) and return contains({{0, 0}, {1, 1}, {2, 2}}, op.padding) and
op.dilation == make_array<size_t>(1, 1); contains({{0, 0}, {1, 1}}, op.stride) and
op.dilation == make_array<size_t>(1, 1);
} }
struct hip_triadd struct hip_triadd
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment