Commit 4ea39116 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 20128cae d8011adf
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
template <migraphx::shape::type_t T>
struct test_shrink : verify_program<test_shrink<T>>
{
migraphx::program create_program() const
{
migraphx::program p;
float bias = 1.5;
float lambd = 1.5;
auto* mm = p.get_main_module();
migraphx::shape is{T, {2, 3}};
std::vector<float> data;
migraphx::shape::visit(T, [&](auto as) {
as.is_signed() ? data.assign({-3.0, -2.0, -1.0, 0.0, 1.0, 2.0})
: data.assign({3.0, 2.0, 1.0, 0.0, 1.0, 2.0});
});
auto x = mm->add_literal(migraphx::literal{is, data});
auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}});
auto lit_neg_lambd =
mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}});
auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}});
auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias});
auto x_min_bias = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias});
auto cond1 = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd});
auto cond2_a = add_common_op(*mm, migraphx::make_op("not"), {cond1});
auto cond2_b = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd});
auto cond2 = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b});
auto mul1 = mm->add_instruction(migraphx::make_op("convert", {{"target_type", T}}), cond1);
auto mul2 = mm->add_instruction(migraphx::make_op("convert", {{"target_type", T}}), cond2);
auto first = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias});
auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias});
auto ret = add_common_op(*mm, migraphx::make_op("add"), {first, second});
if(ret->get_shape().type() != T)
{
mm->add_instruction(migraphx::make_op("convert", {{"target_type", T}}), ret);
}
return p;
}
};
template struct test_shrink<migraphx::shape::double_type>;
template struct test_shrink<migraphx::shape::float_type>;
template struct test_shrink<migraphx::shape::half_type>;
template struct test_shrink<migraphx::shape::int64_type>;
template struct test_shrink<migraphx::shape::int32_type>;
template struct test_shrink<migraphx::shape::int16_type>;
template struct test_shrink<migraphx::shape::int8_type>;
template struct test_shrink<migraphx::shape::uint64_type>;
template struct test_shrink<migraphx::shape::uint32_type>;
template struct test_shrink<migraphx::shape::uint16_type>;
template struct test_shrink<migraphx::shape::uint8_type>;
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_squeeze_conv_relu : verify_program<test_squeeze_conv_relu>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 1, 3, 3}});
input = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), input);
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_instruction(migraphx::make_op("relu"), conv);
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_unsqueeze_conv_relu : verify_program<test_unsqueeze_conv_relu>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {3, 3, 3}});
weights = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), weights);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_instruction(migraphx::make_op("relu"), conv);
return p;
}
};
......@@ -22,4 +22,16 @@
# THE SOFTWARE.
#####################################################################################
add_custom_target(generate bash ${CMAKE_CURRENT_SOURCE_DIR}/generate.sh)
find_package(Python 3 COMPONENTS Interpreter)
if(NOT Python_EXECUTABLE)
message(WARNING "Python 3 interpreter not found - skipping 'generate' target!")
return()
endif()
find_program(CLANG_FORMAT clang-format PATHS /opt/rocm/llvm ENV HIP_PATH PATH_SUFFIXES bin)
if(NOT CLANG_FORMAT)
message(WARNING "clang-format not found - skipping 'generate' target!")
return()
endif()
add_custom_target(generate ${Python_EXECUTABLE} generate.py -f ${CLANG_FORMAT} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -52,6 +52,12 @@ def parse_args():
parser.add_argument('--fill0',
action='store_true',
help='fill all arguments with a value of 0')
parser.add_argument('--fp16',
action='store_true',
help='quantize MIGraphX model to fp16')
parser.add_argument('--argmax',
action='store_true',
help='use argmax for accuracy')
parser.add_argument('--verbose',
action='store_true',
help='show verbose information (for debugging)')
......@@ -105,7 +111,7 @@ def parse_args():
args = parser.parse_args()
return args
return args, parser
# taken from ../test_runner.py
......@@ -113,6 +119,7 @@ def check_correctness(gold_outputs,
outputs,
rtol=1e-3,
atol=1e-3,
use_argmax=False,
verbose=False):
if len(gold_outputs) != len(outputs):
print('Number of outputs {} is not equal to expected number {}'.format(
......@@ -121,18 +128,29 @@ def check_correctness(gold_outputs,
out_num = len(gold_outputs)
ret = True
for i in range(out_num):
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
if not use_argmax:
for i in range(out_num):
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
ret = False
if verbose:
print('\nOutput {} is incorrect ...'.format(i))
print('Expected value: \n{}'.format(gold_outputs[i]))
print('......')
print('Actual value: \n{}\n'.format(outputs[i]))
else:
print('Outputs do not match')
break
else:
golden_argmax = np.argmax(gold_outputs)
actual_argmax = np.argmax(outputs)
if actual_argmax != golden_argmax:
ret = False
print('\nOutput argmax is incorrect ...')
if verbose:
print('\nOutput {} is incorrect ...'.format(i))
print('Expected value: \n{}'.format(gold_outputs[i]))
print('Expected argmax value: \n{}'.format(golden_argmax))
print('......')
print('Actual value: \n{}\n'.format(outputs[i]))
else:
print('Outputs do not match')
break
print('Actual argmax value: \n{}\n'.format(actual_argmax))
return ret
......@@ -155,13 +173,14 @@ def get_np_datatype(in_type):
def main():
args = parse_args()
args, parser = parse_args()
use_onnx = True
if args.onnx == None:
use_onnx = False
if not use_onnx and args.tf == None:
print('Error: please specify either an onnx or tf pb file')
parser.print_help()
sys.exit(-1)
model_name = args.onnx
......@@ -194,6 +213,9 @@ def main():
batch_size=batch,
map_input_dims=input_dims)
if (args.fp16):
migraphx.quantize_fp16(model)
if args.verbose:
print(model)
......@@ -220,10 +242,16 @@ def main():
else:
test_input = np.zeros(in_shape).astype(get_np_datatype(in_type))
test_inputs[name] = test_input
params[name] = migraphx.argument(test_input)
migraphx_arg = migraphx.argument(test_input)
if not args.offload_copy:
migraphx_arg = migraphx.to_gpu(migraphx_arg)
params[name] = migraphx_arg
if not args.ort_run:
pred_migx = np.array(model.run(params)[-1])
if not args.offload_copy:
pred_migx = np.array(migraphx.from_gpu(model.run(params)[-1]))
else:
pred_migx = np.array(model.run(params)[-1])
if use_onnx:
sess_op = ort.SessionOptions()
......@@ -294,7 +322,8 @@ def main():
if not args.ort_run:
is_correct = check_correctness(pred_fw, pred_migx, args.tolerance,
args.tolerance, args.verbose)
args.tolerance, args.argmax,
args.verbose)
verbose_string = ' Rerun with --verbose for detailed information.' \
if not args.verbose else ''
if is_correct:
......
......@@ -22,4 +22,4 @@
# THE SOFTWARE.
#####################################################################################
numpy==1.21.6
onnxruntime==1.10.0
onnxruntime==1.16.2
......@@ -27,6 +27,7 @@ import re
import runpy
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pathlib import Path
type_map: Dict[str, Callable[['Parameter'], None]] = {}
cpp_type_map: Dict[str, str] = {}
......@@ -1281,18 +1282,17 @@ def template_eval(template, **kwargs):
return template
def run(args: List[str]) -> None:
runpy.run_path(args[0])
if len(args) > 1:
f = open(args[1]).read()
r = template_eval(f)
def run(path: Union[Path, str]) -> str:
return template_eval(open(path).read())
if __name__ == "__main__":
sys.modules['api'] = sys.modules['__main__']
runpy.run_path(sys.argv[1])
if len(sys.argv) > 2:
r = run(sys.argv[2])
sys.stdout.write(r)
else:
sys.stdout.write(generate_c_header())
sys.stdout.write(generate_c_api_body())
# sys.stdout.write(generate_cpp_header())
if __name__ == "__main__":
sys.modules['api'] = sys.modules['__main__']
run(sys.argv[1:])
......@@ -38,26 +38,32 @@
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <array>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
#ifdef MIGRAPHX_BUILD_TESTING
static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b)
{
disable_exception_catch = b;
}
#endif
template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT
{
#ifdef MIGRAPHX_BUILD_TESTING
if(disable_exception_catch)
{
f();
}
else
{
#endif
try
{
f();
......@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{
return migraphx_status_unknown_error;
}
#ifdef MIGRAPHX_BUILD_TESTING
}
#endif
return migraphx_status_success;
}
......@@ -156,6 +164,11 @@ void set_default_loop_iterations(onnx_options& options, int64_t value)
options.max_loop_iterations = value;
}
void set_limit_loop_iterations(onnx_options& options, int64_t value)
{
options.limit_max_iterations = value;
}
void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; }
void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; }
......
......@@ -26,6 +26,7 @@
#include <stdlib.h>
#include <stdbool.h>
#include <stdint.h>
#include <migraphx/api/export.h>
......
......@@ -40,4 +40,4 @@ echo 'InferenceSessionTests.CheckRunProfilerWithSessionOptions' >> ../../../tool
echo 'InferenceSessionTests.CheckRunProfilerWithSessionOptions2' >> ../../../tools/ci_build/github/pai/migraphx-excluded-tests.txt
echo 'InferenceSessionTests.Test3LayerNestedSubgraph' >> ../../../tools/ci_build/github/pai/migraphx-excluded-tests.txt
echo 'InferenceSessionTests.Test2LayerNestedSubgraph' >> ../../../tools/ci_build/github/pai/migraphx-excluded-tests.txt
../../../tools/ci_build/github/pai/migraphx_test_launcher.sh || (gdb ./onnxruntime_test_all core -batch -ex bt && exit 1)
../../../tools/ci_build/github/pai/pai_test_launcher.sh || (gdb ./onnxruntime_test_all core -batch -ex bt && exit 1)
......@@ -2,7 +2,7 @@
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -27,11 +27,11 @@ import sys
debug = False
# The filetypes we want to check for that are stamped
# LICENSE is included here as it SHOULD have a liscence in it otherwise flag it as unstamped
# LICENSE is included here as it SHOULD have a license in it otherwise flag it as unstamped
supported_file_types = (".cpp", ".hpp", ".h", ".ipynb", ".py", ".txt", ".sh",
".bsh", "LICENSE", ".cmake")
#add general stuff we shouldn't stamp and any exceptions here
# add general stuff we shouldn't stamp and any exceptions here
unsupported_file_types = [
".onnx", ".pb", ".rst", ".jpg", ".jpeg", ".proto", ".md", ".clang",
".weight", ".ini", ".json", ".docker", ".git", ".rules", ".yml"
......@@ -40,105 +40,89 @@ unsupported_file_types = [
specificIgnores = ("digits.txt", "Dockerfile", "Jenkinsfile", "")
def hasKeySequence(inputfile, key_message):
result = False
def hasKeySequence(inputfile: str, key_message: str) -> bool:
if key_message in inputfile:
result = True
return result
return True
return False
#Simple just open and write stuff to each file with the license stamp
def openAndCheckFile(filename):
result = False
#open save old contents and append things here
if debug is True:
print("Open", filename, end='')
# Simple just open and write stuff to each file with the license stamp
def needStampCheck(filename: str) -> bool:
# open save old contents and append things here
if debug: print("Open", filename, end=' ')
try:
file = open(filename, 'r')
except OSError as e:
if debug is True:
print(str(e) + "....Open Error: Skipping file ")
if debug: print(str(e) + "....Open Error: Skipping file ")
file.close()
return
return False
else:
with file as contents:
try:
save = contents.read()
hasAmdLic = hasKeySequence(
save, "Advanced Micro Devices, Inc. All rights reserved")
#Check if we have a licence stamp already
if hasAmdLic is True:
if debug is True:
print("....Already Stamped: Skipping file ")
# Check if we have a license stamp already
if hasKeySequence(
save,
"Advanced Micro Devices, Inc. All rights reserved"):
if debug: print("....Already Stamped: Skipping file ")
contents.close()
result = True
return False
except UnicodeDecodeError as eu:
if debug is True:
print(str(eu) + "...Skipping binary file ")
if debug: print(f"{str(eu)}...Skipping binary file ")
contents.close()
result = True
return False
return result
return True
# Deterine if filename is desired in the fileTuple past in
def check_filename(filename, fileTuple):
supported = False
for key in fileTuple:
if key in filename:
supported = True
break
return supported
# Check if any element in fileTuple is in filename
def check_filename(filename: str, fileTuple: tuple or list) -> bool:
if any([x in filename for x in fileTuple]):
return True
return False
def main():
def main() -> None:
unsupported_file_types.extend(specificIgnores)
#Get a list of all the tracked files in our git repo
# Get a list of all the tracked files in our git repo
proc = subprocess.run("git ls-files --exclude-standard",
shell=True,
stdout=subprocess.PIPE)
fileList = proc.stdout.decode().split('\n')
if debug is True:
print("Target file list:\n" + str(fileList))
if debug: print("Target file list:\n" + str(fileList))
unsupportedFiles = []
unstampedFiles = []
unknownFiles = []
for file in fileList:
supported = check_filename(file, supported_file_types)
if supported is True:
isStamped = openAndCheckFile(file)
if isStamped is False:
if check_filename(file, supported_file_types):
if needStampCheck(file):
unstampedFiles.append(file)
elif check_filename(file, unsupported_file_types):
unsupportedFiles.append(file)
else:
unsupported = check_filename(file, unsupported_file_types)
if unsupported is True:
unsupportedFiles.append(file)
else:
unknownFiles.append(file)
unknownFiles.append(file)
#Do a bunch of checks based on our file lists
# Do a bunch of checks based on our file lists
if len(unstampedFiles) > 0:
print("Error: The following " + str(len(unstampedFiles)) +
print("\nError: The following " + str(len(unstampedFiles)) +
" files are currently without a license:")
print(str(unstampedFiles))
sys.exit(1)
if len(unknownFiles) > 0:
print("Error: The following " + str(len(unknownFiles)) +
print("\nError: The following " + str(len(unknownFiles)) +
" files not handled:")
print(str(unknownFiles))
sys.exit(2)
sys.exit(0)
if __name__ == "__main__":
main()
......@@ -90,11 +90,6 @@ RUN pip3 install yapf==0.28.0
ADD docs/.sphinx/requirements.txt /doc-requirements.txt
RUN pip3 install -r /doc-requirements.txt
# Download real models to run onnx unit tests
ENV ONNX_HOME=/.onnx
COPY ./tools/download_models.sh /
RUN /download_models.sh && rm /download_models.sh
# Install latest ccache version
RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake
RUN cget -p $PREFIX install ccache@v4.1 -DENABLE_TESTING=OFF
......
#!/bin/bash
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -23,34 +21,68 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import api, argparse, os, runpy, subprocess, sys, te
from pathlib import Path
clang_format_path = Path('clang-format.exe' if os.name ==
'nt' else '/opt/rocm/llvm/bin/clang-format')
work_dir = Path().cwd()
src_dir = (work_dir / '../src').absolute()
migraphx_py_path = src_dir / 'api/migraphx.py'
def clang_format(buffer, **kwargs):
return subprocess.run(f'{clang_format_path} -style=file',
capture_output=True,
shell=True,
check=True,
input=buffer.encode('utf-8'),
cwd=work_dir,
**kwargs).stdout.decode('utf-8')
def api_generate(input_path: Path, output_path: Path):
with open(output_path, 'w') as f:
f.write(clang_format(api.run(input_path)))
def te_generate(input_path: Path, output_path: Path):
with open(output_path, 'w') as f:
f.write(clang_format(te.run(input_path)))
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--clang-format', type=Path)
args = parser.parse_args()
global clang_format_path
if args.clang_format:
clang_format_path = args.clang_format
if not clang_format_path.is_file():
print(f"{clang_format_path}: invalid path or not installed",
file=sys.stderr)
return
try:
files = Path('include').absolute().iterdir()
for f in [f for f in files if f.is_file()]:
te_generate(f, src_dir / f'include/migraphx/{f.name}')
runpy.run_path(str(migraphx_py_path))
api_generate(work_dir / 'api/migraphx.h',
src_dir / 'api/include/migraphx/migraphx.h')
print('Finished generating header migraphx.h')
api_generate(work_dir / 'api/api.cpp', src_dir / 'api/api.cpp')
print('Finished generating source api.cpp')
except subprocess.CalledProcessError as ex:
if ex.stdout:
print(ex.stdout.decode('utf-8'))
if ex.stderr:
print(ex.stdout.decode('utf-8'))
print(f"Command '{ex.cmd}' returned {ex.returncode}")
raise
set -e
if [ -z "$ONNX_HOME" ]
then
# The onnx library uses ONNX_HOME, by default if it doesn't exist
# the path of " ~/.onnx " is used
ONNX_HOME=$HOME/.onnx
fi
model_dir=$ONNX_HOME/models
tmp_dir=$ONNX_HOME/tmp/
mkdir -p $model_dir
mkdir -p $tmp_dir
models="bvlc_alexnet \
densenet121 \
inception_v2 \
shufflenet \
vgg19 \
zfnet512"
for name in $models
do
curl https://download.onnxruntime.ai/onnx/models/$name.tar.gz --output $tmp_dir/$name.tar.gz
tar -xzvf $tmp_dir/$name.tar.gz --directory $model_dir && rm $tmp_dir/$name.tar.gz
done
# CI jobs can run as a different user then the docker image builder.
# Allow read/write access to the models
chmod 777 $model_dir
if __name__ == "__main__":
main()
......@@ -3,7 +3,7 @@
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -51,6 +51,7 @@ else
openmp-extras \
python3-dev \
python3-pip \
python3-venv \
rocblas-dev \
rocm-cmake
fi
......
......@@ -431,6 +431,9 @@ def template_eval(template, **kwargs):
return template
f = open(sys.argv[1]).read()
r = template_eval(f)
sys.stdout.write(r)
def run(p):
return template_eval(open(p).read())
if __name__ == '__main__':
sys.stdout.write(run(sys.argv[1]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment