Commit f85ba189 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'pointwise-nhwc' of...

Merge branch 'pointwise-nhwc' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into nhwc_workaround
parents 122ffe97 dfbab16e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_convolution_backwards_2d_alt : verify_program<test_convolution_backwards_2d_alt>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 10, 10}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}});
mm->add_instruction(
migraphx::make_op("convolution_backwards",
{{"padding", {2, 2}}, {"stride", {2, 2}}, {"dilation", {2, 2}}}),
input,
weights);
return p;
}
};
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_deconv_2x3 : verify_program<test_deconv_2x3> struct test_convolution_backwards_2x3 : verify_program<test_convolution_backwards_2x3>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -38,7 +38,7 @@ struct test_deconv_2x3 : verify_program<test_deconv_2x3> ...@@ -38,7 +38,7 @@ struct test_deconv_2x3 : verify_program<test_deconv_2x3>
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {3, 4, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {3, 4, 3, 3}});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("deconvolution", migraphx::make_op("convolution_backwards",
{{"padding", {1, 1}}, {"stride", {2, 3}}, {"dilation", {1, 1}}}), {{"padding", {1, 1}}, {"stride", {2, 3}}, {"dilation", {1, 1}}}),
input, input,
weights); weights);
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_deconv_3d : verify_program<test_deconv_3d> struct test_convolution_backwards_3d : verify_program<test_convolution_backwards_3d>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -39,7 +39,7 @@ struct test_deconv_3d : verify_program<test_deconv_3d> ...@@ -39,7 +39,7 @@ struct test_deconv_3d : verify_program<test_deconv_3d>
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3, 3}});
mm->add_instruction( mm->add_instruction(
migraphx::make_op( migraphx::make_op(
"deconvolution", "convolution_backwards",
{{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}), {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}),
input, input,
weights); weights);
......
...@@ -675,8 +675,8 @@ bool has_finalize(const T& x) ...@@ -675,8 +675,8 @@ bool has_finalize(const T& x)
return detail::has_finalize_op(x); return detail::has_finalize_op(x);
} }
void migraphx_to_value(value& v, const operation& op); MIGRAPHX_EXPORT void migraphx_to_value(value& v, const operation& op);
void migraphx_from_value(const value& v, operation& op); MIGRAPHX_EXPORT void migraphx_from_value(const value& v, operation& op);
#endif #endif
......
...@@ -57,7 +57,7 @@ struct pass ...@@ -57,7 +57,7 @@ struct pass
#else #else
module& get_module(module_pass_manager& mpm); MIGRAPHX_EXPORT module& get_module(module_pass_manager& mpm);
namespace detail { namespace detail {
......
...@@ -28,6 +28,8 @@ trivial = [ ...@@ -28,6 +28,8 @@ trivial = [
'bool', 'any_ptr' 'bool', 'any_ptr'
] ]
export_macro = 'MIGRAPHX_EXPORT'
headers = ''' headers = '''
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
...@@ -41,7 +43,7 @@ form = string.Template(''' ...@@ -41,7 +43,7 @@ form = string.Template('''
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for: // Type-erased interface for:
struct ${struct_name} struct ${export_macro} ${struct_name}
{ {
${decl_members} ${decl_members}
}; };
...@@ -395,7 +397,8 @@ def generate_form(name, members): ...@@ -395,7 +397,8 @@ def generate_form(name, members):
default_members=''.join(default_members), default_members=''.join(default_members),
decl_members=''.join(decl_members), decl_members=''.join(decl_members),
comment_members='\n'.join(comment_members), comment_members='\n'.join(comment_members),
struct_name=name) struct_name=name,
export_macro=export_macro)
def virtual(name, returns=None, **kwargs): def virtual(name, returns=None, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment