Unverified Commit 9d3fb0b5 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into enable_navi_32_ci

parents 9c91c08d aeb9f78c
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
file(GLOB VERIFY_TESTS ${CONFIGURE_DEPENDS} *.cpp) file(GLOB VERIFY_TESTS CONFIGURE_DEPENDS *.cpp)
add_executable(test_verify ${VERIFY_TESTS}) add_executable(test_verify ${VERIFY_TESTS})
add_dependencies(tests test_verify) add_dependencies(tests test_verify)
......
...@@ -88,10 +88,31 @@ inline void compile_check(migraphx::program& p, ...@@ -88,10 +88,31 @@ inline void compile_check(migraphx::program& p,
auto num = shapes.size(); auto num = shapes.size();
for(std::size_t i = 0; i < num; ++i) for(std::size_t i = 0; i < num; ++i)
{ {
if(p.get_output_shapes()[i].lens() != shapes[i].lens()) auto output_shape = p.get_output_shapes()[i];
if(output_shape.dynamic() and shapes[i].dynamic())
{
if(output_shape.dyn_dims() != shapes[i].dyn_dims())
{
std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name +
" alters its dynamic output dimensions");
}
}
else if(not(output_shape.dynamic() or shapes[i].dynamic()))
{
if(output_shape.lens() != shapes[i].lens())
{
std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name +
" alters its static output dimensions");
}
}
else
{ {
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name + " alters its shape"); throw std::runtime_error(
"Compiling program with " + name +
" alters its output dimensions (static shape vs dynamic shape)");
} }
} }
if(t.name() != "ref") if(t.name() != "ref")
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_add_nhwc : verify_program<test_add_nhwc>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape::from_permutation(
migraphx::shape::float_type, {4, 3, 8, 8}, {0, 2, 3, 1});
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({add});
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/op/pooling.hpp>
struct test_avg_pooling_pad : verify_program<test_avg_pooling_pad>
{
migraphx::program create_program() const
{
// pooling test with nonzero padding
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 7}});
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average, {2}, {1}, {3}};
mm->add_instruction(op, input);
return p;
}
};
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_deconv : verify_program<test_deconv> struct test_convolution_backwards : verify_program<test_convolution_backwards>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -37,7 +37,7 @@ struct test_deconv : verify_program<test_deconv> ...@@ -37,7 +37,7 @@ struct test_deconv : verify_program<test_deconv>
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}});
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}});
mm->add_instruction(migraphx::make_op("deconvolution"), input, weights); mm->add_instruction(migraphx::make_op("convolution_backwards"), input, weights);
return p; return p;
} }
}; };
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_deconv_1d : verify_program<test_deconv_1d> struct test_convolution_backwards_1d : verify_program<test_convolution_backwards_1d>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -38,7 +38,7 @@ struct test_deconv_1d : verify_program<test_deconv_1d> ...@@ -38,7 +38,7 @@ struct test_deconv_1d : verify_program<test_deconv_1d>
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3}});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("deconvolution", migraphx::make_op("convolution_backwards",
{{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input, input,
weights); weights);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_convolution_backwards_2d_alt : verify_program<test_convolution_backwards_2d_alt>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 10, 10}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}});
mm->add_instruction(
migraphx::make_op("convolution_backwards",
{{"padding", {2, 2}}, {"stride", {2, 2}}, {"dilation", {2, 2}}}),
input,
weights);
return p;
}
};
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_deconv_2x3 : verify_program<test_deconv_2x3> struct test_convolution_backwards_2x3 : verify_program<test_convolution_backwards_2x3>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -38,7 +38,7 @@ struct test_deconv_2x3 : verify_program<test_deconv_2x3> ...@@ -38,7 +38,7 @@ struct test_deconv_2x3 : verify_program<test_deconv_2x3>
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {3, 4, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {3, 4, 3, 3}});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("deconvolution", migraphx::make_op("convolution_backwards",
{{"padding", {1, 1}}, {"stride", {2, 3}}, {"dilation", {1, 1}}}), {{"padding", {1, 1}}, {"stride", {2, 3}}, {"dilation", {1, 1}}}),
input, input,
weights); weights);
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_deconv_3d : verify_program<test_deconv_3d> struct test_convolution_backwards_3d : verify_program<test_convolution_backwards_3d>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -39,7 +39,7 @@ struct test_deconv_3d : verify_program<test_deconv_3d> ...@@ -39,7 +39,7 @@ struct test_deconv_3d : verify_program<test_deconv_3d>
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3, 3}});
mm->add_instruction( mm->add_instruction(
migraphx::make_op( migraphx::make_op(
"deconvolution", "convolution_backwards",
{{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}), {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}),
input, input,
weights); weights);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape::from_permutation(
migraphx::shape::float_type, {4, 256, 2, 2}, {0, 2, 3, 1});
auto x = mm->add_parameter("x", s);
auto reduce = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), x);
auto abs = mm->add_instruction(migraphx::make_op("abs"), reduce);
auto sqrt = mm->add_instruction(migraphx::make_op("sqrt"), abs);
mm->add_return({sqrt});
return p;
};
};
...@@ -36,6 +36,8 @@ error_type = '' ...@@ -36,6 +36,8 @@ error_type = ''
success_type = '' success_type = ''
try_wrap = '' try_wrap = ''
export_c_macro = 'MIGRAPHX_C_EXPORT'
c_header_preamble: List[str] = [] c_header_preamble: List[str] = []
c_api_body_preamble: List[str] = [] c_api_body_preamble: List[str] = []
cpp_header_preamble: List[str] = [] cpp_header_preamble: List[str] = []
...@@ -125,7 +127,7 @@ class Type: ...@@ -125,7 +127,7 @@ class Type:
header_function = Template(''' header_function = Template('''
${error_type} ${name}(${params}); ${export_c_macro} ${error_type} ${name}(${params});
''') ''')
function_pointer_typedef = Template(''' function_pointer_typedef = Template('''
...@@ -177,7 +179,7 @@ class CFunction: ...@@ -177,7 +179,7 @@ class CFunction:
**kwargs) **kwargs)
def generate_header(self) -> str: def generate_header(self) -> str:
return self.substitute(header_function) return self.substitute(header_function, export_c_macro=export_c_macro)
def generate_function_pointer(self, name: Optional[str] = None) -> str: def generate_function_pointer(self, name: Optional[str] = None) -> str:
return self.substitute(function_pointer_typedef, return self.substitute(function_pointer_typedef,
......
...@@ -44,7 +44,7 @@ namespace migraphx { ...@@ -44,7 +44,7 @@ namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b) extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b)
{ {
disable_exception_catch = b; disable_exception_catch = b;
} }
......
...@@ -26,6 +26,9 @@ ...@@ -26,6 +26,9 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h> #include <stdbool.h>
#include <migraphx/api/export.h>
// Add new types here // Add new types here
// clang-format off // clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \ #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
......
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import os, shutil, argparse, subprocess
CLANG_FORMAT_PATH = '/opt/rocm/llvm/bin'
def run(cmd, **kwargs):
print(cmd)
subprocess.run(cmd, shell=True, check=True, **kwargs)
def eval(cmd, **kwargs):
return subprocess.run(cmd,
capture_output=True,
shell=True,
check=True,
**kwargs).stdout.decode('utf-8').strip()
def get_top():
return eval("git rev-parse --show-toplevel")
def get_head():
return eval("git rev-parse --abbrev-ref HEAD")
def get_merge_base(branch):
head = get_head()
return eval(f"git merge-base {branch} {head}")
def clang_format(against, apply=False, path=CLANG_FORMAT_PATH):
base = get_merge_base(against)
clang_format = os.path.join(path, 'clang-format')
if not os.path.exists(clang_format):
print(f"{clang_format} not installed. Skipping format.")
return
git_clang_format = os.path.join(path, 'git-clang-format')
if not os.path.exists(git_clang_format):
print(f"{git_clang_format} not installed. Skipping format.")
return
diff_flag = "" if apply else "--diff"
run(f"{git_clang_format} --binary {clang_format} {diff_flag} {base}")
def get_files_changed(against, ext=('py')):
files = eval(f"git diff-index --cached --name-only {against}",
cwd=get_top()).splitlines()
return (f for f in files if f.endswith(ext))
def yapf_format(against, apply=False):
if not shutil.which('yapf'):
print("yapf not installed. Skipping format.")
return
diff_flag = "--in-place" if apply else "--diff"
files = ' '.join(get_files_changed(against))
if files:
run(f"yapf {diff_flag} -p {files}")
else:
print("No modified python files to format")
def main():
parser = argparse.ArgumentParser()
parser.add_argument('against', default='develop', nargs='?')
parser.add_argument('-i', '--in-place', action='store_true')
parser.add_argument('-q', '--quiet', action='store_true')
args = parser.parse_args()
try:
clang_format(args.against, apply=args.in_place)
yapf_format(args.against, apply=args.in_place)
except subprocess.CalledProcessError as ex:
if ex.stdout:
print(ex.stdout.decode('utf-8'))
if ex.stderr:
print(ex.stderr.decode('utf-8'))
if not args.quiet:
print(f"Command '{ex.cmd}' returned {ex.returncode}")
raise
# sys.exit(ex.returncode)
if __name__ == "__main__":
main()
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
CLANG_FORMAT=/opt/rocm/llvm/bin/clang-format
SRC_DIR=$DIR/../src SRC_DIR=$DIR/../src
PYTHON=python3 PYTHON=python3
if type -p python3.6 > /dev/null ; then if type -p python3.6 > /dev/null ; then
...@@ -30,10 +31,10 @@ fi ...@@ -30,10 +31,10 @@ fi
if type -p python3.8 > /dev/null ; then if type -p python3.8 > /dev/null ; then
PYTHON=python3.8 PYTHON=python3.8
fi fi
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "$PYTHON $DIR/te.py $DIR/include/{} | clang-format-10 -style=file > $SRC_DIR/include/migraphx/{}" ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "$PYTHON $DIR/te.py $DIR/include/{} | $CLANG_FORMAT -style=file > $SRC_DIR/include/migraphx/{}"
function api { function api {
$PYTHON $DIR/api.py $SRC_DIR/api/migraphx.py $1 | clang-format-10 -style=file > $2 $PYTHON $DIR/api.py $SRC_DIR/api/migraphx.py $1 | $CLANG_FORMAT -style=file > $2
} }
api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h
......
...@@ -143,7 +143,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs) ...@@ -143,7 +143,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
if(inputs.empty()) if(inputs.empty())
MIGRAPHX_THROW("At least one input is required for " + x.name()); MIGRAPHX_THROW("At least one input is required for " + x.name());
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].max_lens()); normalize_attributes(y, inputs[0]);
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
} }
...@@ -675,8 +675,8 @@ bool has_finalize(const T& x) ...@@ -675,8 +675,8 @@ bool has_finalize(const T& x)
return detail::has_finalize_op(x); return detail::has_finalize_op(x);
} }
void migraphx_to_value(value& v, const operation& op); MIGRAPHX_EXPORT void migraphx_to_value(value& v, const operation& op);
void migraphx_from_value(const value& v, operation& op); MIGRAPHX_EXPORT void migraphx_from_value(const value& v, operation& op);
#endif #endif
......
...@@ -57,7 +57,7 @@ struct pass ...@@ -57,7 +57,7 @@ struct pass
#else #else
module& get_module(module_pass_manager& mpm); MIGRAPHX_EXPORT module& get_module(module_pass_manager& mpm);
namespace detail { namespace detail {
......
...@@ -28,6 +28,8 @@ trivial = [ ...@@ -28,6 +28,8 @@ trivial = [
'bool', 'any_ptr' 'bool', 'any_ptr'
] ]
export_macro = 'MIGRAPHX_EXPORT'
headers = ''' headers = '''
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
...@@ -41,7 +43,7 @@ form = string.Template(''' ...@@ -41,7 +43,7 @@ form = string.Template('''
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for: // Type-erased interface for:
struct ${struct_name} struct ${export_macro} ${struct_name}
{ {
${decl_members} ${decl_members}
}; };
...@@ -68,7 +70,7 @@ struct ${struct_name} ...@@ -68,7 +70,7 @@ struct ${struct_name}
{ {
using std::swap; using std::swap;
auto * derived = this->any_cast<PrivateDetailTypeErasedT>(); auto * derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique()) if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{ {
*derived = std::forward<PrivateDetailTypeErasedT>(value); *derived = std::forward<PrivateDetailTypeErasedT>(value);
} }
...@@ -179,7 +181,7 @@ private: ...@@ -179,7 +181,7 @@ private:
private_detail_te_handle_base_type & private_detail_te_get_handle () private_detail_te_handle_base_type & private_detail_te_get_handle ()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if (not private_detail_te_handle_mem_var.unique()) if (private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
...@@ -395,7 +397,8 @@ def generate_form(name, members): ...@@ -395,7 +397,8 @@ def generate_form(name, members):
default_members=''.join(default_members), default_members=''.join(default_members),
decl_members=''.join(decl_members), decl_members=''.join(decl_members),
comment_members='\n'.join(comment_members), comment_members='\n'.join(comment_members),
struct_name=name) struct_name=name,
export_macro=export_macro)
def virtual(name, returns=None, **kwargs): def virtual(name, returns=None, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment