Commit 264a7647 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Merge branch 'develop' into multinomial_parse_merge

parents d99729f8 8e18544f
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_convolution_backwards_2d_alt : verify_program<test_convolution_backwards_2d_alt>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 10, 10}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}});
mm->add_instruction(
migraphx::make_op("convolution_backwards",
{{"padding", {2, 2}}, {"stride", {2, 2}}, {"dilation", {2, 2}}}),
input,
weights);
return p;
}
};
......@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_deconv_2x3 : verify_program<test_deconv_2x3>
struct test_convolution_backwards_2x3 : verify_program<test_convolution_backwards_2x3>
{
migraphx::program create_program() const
{
......@@ -38,7 +38,7 @@ struct test_deconv_2x3 : verify_program<test_deconv_2x3>
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {3, 4, 3, 3}});
mm->add_instruction(
migraphx::make_op("deconvolution",
migraphx::make_op("convolution_backwards",
{{"padding", {1, 1}}, {"stride", {2, 3}}, {"dilation", {1, 1}}}),
input,
weights);
......
......@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_deconv_3d : verify_program<test_deconv_3d>
struct test_convolution_backwards_3d : verify_program<test_convolution_backwards_3d>
{
migraphx::program create_program() const
{
......@@ -39,7 +39,7 @@ struct test_deconv_3d : verify_program<test_deconv_3d>
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3, 3}});
mm->add_instruction(
migraphx::make_op(
"deconvolution",
"convolution_backwards",
{{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}),
input,
weights);
......
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import os, shutil, argparse, subprocess
CLANG_FORMAT_PATH = '/opt/rocm/llvm/bin'
def run(cmd, **kwargs):
print(cmd)
subprocess.run(cmd, shell=True, check=True, **kwargs)
def eval(cmd, **kwargs):
return subprocess.run(cmd,
capture_output=True,
shell=True,
check=True,
**kwargs).stdout.decode('utf-8').strip()
def get_top():
return eval("git rev-parse --show-toplevel")
def get_head():
return eval("git rev-parse --abbrev-ref HEAD")
def get_merge_base(branch):
head = get_head()
return eval(f"git merge-base {branch} {head}")
def clang_format(against, apply=False, path=CLANG_FORMAT_PATH):
base = get_merge_base(against)
clang_format = os.path.join(path, 'clang-format')
if not os.path.exists(clang_format):
print(f"{clang_format} not installed. Skipping format.")
return
git_clang_format = os.path.join(path, 'git-clang-format')
if not os.path.exists(git_clang_format):
print(f"{git_clang_format} not installed. Skipping format.")
return
diff_flag = "" if apply else "--diff"
run(f"{git_clang_format} --binary {clang_format} {diff_flag} {base}")
def get_files_changed(against, ext=('py')):
files = eval(f"git diff-index --cached --name-only {against}",
cwd=get_top()).splitlines()
return (f for f in files if f.endswith(ext))
def yapf_format(against, apply=False):
if not shutil.which('yapf'):
print("yapf not installed. Skipping format.")
return
diff_flag = "--in-place" if apply else "--diff"
files = ' '.join(get_files_changed(against))
if files:
run(f"yapf {diff_flag} -p {files}")
else:
print("No modified python files to format")
def main():
parser = argparse.ArgumentParser()
parser.add_argument('against', default='develop', nargs='?')
parser.add_argument('-i', '--in-place', action='store_true')
parser.add_argument('-q', '--quiet', action='store_true')
args = parser.parse_args()
try:
clang_format(args.against, apply=args.in_place)
yapf_format(args.against, apply=args.in_place)
except subprocess.CalledProcessError as ex:
if ex.stdout:
print(ex.stdout.decode('utf-8'))
if ex.stderr:
print(ex.stderr.decode('utf-8'))
if not args.quiet:
print(f"Command '{ex.cmd}' returned {ex.returncode}")
raise
# sys.exit(ex.returncode)
if __name__ == "__main__":
main()
......@@ -22,6 +22,7 @@
# THE SOFTWARE.
#####################################################################################
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
CLANG_FORMAT=/opt/rocm/llvm/bin/clang-format
SRC_DIR=$DIR/../src
PYTHON=python3
if type -p python3.6 > /dev/null ; then
......@@ -30,10 +31,10 @@ fi
if type -p python3.8 > /dev/null ; then
PYTHON=python3.8
fi
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "$PYTHON $DIR/te.py $DIR/include/{} | clang-format-10 -style=file > $SRC_DIR/include/migraphx/{}"
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "$PYTHON $DIR/te.py $DIR/include/{} | $CLANG_FORMAT -style=file > $SRC_DIR/include/migraphx/{}"
function api {
$PYTHON $DIR/api.py $SRC_DIR/api/migraphx.py $1 | clang-format-10 -style=file > $2
$PYTHON $DIR/api.py $SRC_DIR/api/migraphx.py $1 | $CLANG_FORMAT -style=file > $2
}
api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h
......
......@@ -143,7 +143,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
if(inputs.empty())
MIGRAPHX_THROW("At least one input is required for " + x.name());
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].max_lens());
normalize_attributes(y, inputs[0]);
return any_cast<T>(y).normalize_compute_shape(inputs);
}
......@@ -261,11 +261,13 @@ auto compute_op(rank<1>,
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F)
{
if(module_args.empty())
return compute_op(x, output, inputs);
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
......@@ -673,8 +675,8 @@ bool has_finalize(const T& x)
return detail::has_finalize_op(x);
}
void migraphx_to_value(value& v, const operation& op);
void migraphx_from_value(const value& v, operation& op);
MIGRAPHX_EXPORT void migraphx_to_value(value& v, const operation& op);
MIGRAPHX_EXPORT void migraphx_from_value(const value& v, operation& op);
#endif
......
......@@ -57,7 +57,7 @@ struct pass
#else
module& get_module(module_pass_manager& mpm);
MIGRAPHX_EXPORT module& get_module(module_pass_manager& mpm);
namespace detail {
......
......@@ -45,6 +45,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct value;
#ifdef DOXYGEN
/// An interface for a compilation target
......@@ -123,28 +125,41 @@ supported_segments target_find_supported(T&, const_module_ref, support_metric)
}
<%
interface('target',
virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True),
virtual('find_supported', returns='supported_segments', mod='const_module_ref', m='support_metric', const=True, default='target_find_supported'),
virtual('copy_to',
returns = 'argument',
input = 'const argument&',
const = True,
default = 'copy_to_target'),
virtual('copy_from',
returns = 'argument',
input = 'const argument&',
const = True,
default = 'copy_from_target'),
virtual('allocate', s='const shape&', returns='argument', const=True,
default = 'target_allocate')
)
%>
interface('target',
virtual('name', returns = 'std::string', const = True),
virtual('get_passes',
ctx = 'context&',
options = 'const compile_options&',
returns = 'std::vector<pass>',
const = True),
virtual('get_context', returns = 'context', const = True),
virtual('find_supported',
returns = 'supported_segments',
mod = 'const_module_ref',
m = 'support_metric',
const = True,
default = 'target_find_supported'),
virtual('copy_to',
returns = 'argument',
input = 'const argument&',
const = True,
default = 'copy_to_target'),
virtual('copy_from',
returns = 'argument',
input = 'const argument&',
const = True,
default = 'copy_from_target'),
virtual('allocate',
s = 'const shape&',
returns = 'argument',
const = True,
default = 'target_allocate')) %>
#endif
void migraphx_to_value(value& v, const target& t);
void migraphx_from_value(const value& v, target& t);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -28,6 +28,8 @@ trivial = [
'bool', 'any_ptr'
]
export_macro = 'MIGRAPHX_EXPORT'
headers = '''
#include <algorithm>
#include <cassert>
......@@ -41,7 +43,7 @@ form = string.Template('''
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct ${struct_name}
struct ${export_macro} ${struct_name}
{
${decl_members}
};
......@@ -68,7 +70,7 @@ struct ${struct_name}
{
using std::swap;
auto * derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
......@@ -179,7 +181,7 @@ private:
private_detail_te_handle_base_type & private_detail_te_get_handle ()
{
assert(private_detail_te_handle_mem_var != nullptr);
if (not private_detail_te_handle_mem_var.unique())
if (private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......@@ -395,7 +397,8 @@ def generate_form(name, members):
default_members=''.join(default_members),
decl_members=''.join(decl_members),
comment_members='\n'.join(comment_members),
struct_name=name)
struct_name=name,
export_macro=export_macro)
def virtual(name, returns=None, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment