Unverified Commit a24ed87e authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into optimize_jenkinsfile

parents 6481cd69 a09dc502
......@@ -27,13 +27,14 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_topk_0 : verify_program<test_topk_0>
template <migraphx::shape::type_t DType>
struct test_topk_0 : verify_program<test_topk_0<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 5}};
migraphx::shape s{DType, {3, 5}};
auto data = mm->add_parameter("data", s);
auto r = mm->add_instruction(
migraphx::make_op("topk", {{"axis", 1}, {"k", 4}, {"largest", 1}}), data);
......@@ -43,3 +44,7 @@ struct test_topk_0 : verify_program<test_topk_0>
return p;
}
};
template struct test_topk_0<migraphx::shape::float_type>;
template struct test_topk_0<migraphx::shape::half_type>;
template struct test_topk_0<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,15 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
template <migraphx::shape::type_t DType>
struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 32, 64}};
migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 32, 192}};
migraphx::shape m1_shape{DType, {2, 32, 64}};
migraphx::shape m2_shape{DType, {64, 64}};
migraphx::shape m3_shape{DType, {2, 32, 192}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
......@@ -56,3 +58,7 @@ struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
return p;
}
};
template struct test_unbatched_gemm_1<migraphx::shape::float_type>;
template struct test_unbatched_gemm_1<migraphx::shape::half_type>;
template struct test_unbatched_gemm_1<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,14 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
template <migraphx::shape::type_t DType>
struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {4, 32, 64}};
migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}};
migraphx::shape m1_shape{DType, {4, 32, 64}};
migraphx::shape m2_shape{DType, {64, 64}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 64, 64}}}),
......@@ -44,3 +46,7 @@ struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
return p;
}
};
template struct test_unbatched_gemm_2<migraphx::shape::float_type>;
template struct test_unbatched_gemm_2<migraphx::shape::half_type>;
template struct test_unbatched_gemm_2<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,7 +27,8 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_where : verify_program<test_where>
template <migraphx::shape::type_t DType>
struct test_where : verify_program<test_where<DType>>
{
migraphx::program create_program() const
{
......@@ -44,3 +45,7 @@ struct test_where : verify_program<test_where>
return p;
};
};
template struct test_where<migraphx::shape::float_type>;
template struct test_where<migraphx::shape::half_type>;
template struct test_where<migraphx::shape::fp8e4m3fnuz_type>;
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -52,6 +52,12 @@ def parse_args():
parser.add_argument('--fill0',
action='store_true',
help='fill all arguments with a value of 0')
parser.add_argument('--fp16',
action='store_true',
help='quantize MIGraphX model to fp16')
parser.add_argument('--argmax',
action='store_true',
help='use argmax for accuracy')
parser.add_argument('--verbose',
action='store_true',
help='show verbose information (for debugging)')
......@@ -105,7 +111,7 @@ def parse_args():
args = parser.parse_args()
return args
return args, parser
# taken from ../test_runner.py
......@@ -113,6 +119,7 @@ def check_correctness(gold_outputs,
outputs,
rtol=1e-3,
atol=1e-3,
use_argmax=False,
verbose=False):
if len(gold_outputs) != len(outputs):
print('Number of outputs {} is not equal to expected number {}'.format(
......@@ -121,18 +128,30 @@ def check_correctness(gold_outputs,
out_num = len(gold_outputs)
ret = True
if not use_argmax:
for i in range(out_num):
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
ret = False
if verbose:
with np.printoptions(threshold=np.inf):
print('\nOutput {} is incorrect ...'.format(i))
print('Expected value: \n{}'.format(gold_outputs[i]))
print('......')
print('Expected value: \n{}\n'.format(gold_outputs[i]))
print('\n......\n')
print('Actual value: \n{}\n'.format(outputs[i]))
else:
print('Outputs do not match')
break
else:
golden_argmax = np.argmax(gold_outputs)
actual_argmax = np.argmax(outputs)
if actual_argmax != golden_argmax:
ret = False
print('\nOutput argmax is incorrect ...')
if verbose:
print('Expected argmax value: \n{}'.format(golden_argmax))
print('......')
print('Actual argmax value: \n{}\n'.format(actual_argmax))
return ret
......@@ -155,13 +174,14 @@ def get_np_datatype(in_type):
def main():
args = parse_args()
args, parser = parse_args()
use_onnx = True
if args.onnx == None:
use_onnx = False
if not use_onnx and args.tf == None:
print('Error: please specify either an onnx or tf pb file')
parser.print_help()
sys.exit(-1)
model_name = args.onnx
......@@ -194,6 +214,9 @@ def main():
batch_size=batch,
map_input_dims=input_dims)
if (args.fp16):
migraphx.quantize_fp16(model)
if args.verbose:
print(model)
......@@ -300,7 +323,8 @@ def main():
if not args.ort_run:
is_correct = check_correctness(pred_fw, pred_migx, args.tolerance,
args.tolerance, args.verbose)
args.tolerance, args.argmax,
args.verbose)
verbose_string = ' Rerun with --verbose for detailed information.' \
if not args.verbose else ''
if is_correct:
......
......@@ -22,4 +22,4 @@
# THE SOFTWARE.
#####################################################################################
numpy==1.21.6
onnxruntime==1.16.1
onnxruntime==1.16.3
......@@ -164,6 +164,11 @@ void set_default_loop_iterations(onnx_options& options, int64_t value)
options.max_loop_iterations = value;
}
void set_limit_loop_iterations(onnx_options& options, int64_t value)
{
options.limit_max_iterations = value;
}
void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; }
void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; }
......
......@@ -44,7 +44,8 @@
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
#ifdef __cplusplus
......@@ -70,7 +71,9 @@ typedef enum
} migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
<% generate_c_header() %>
<%
generate_c_header()
%>
#ifdef __cplusplus
}
......
......@@ -90,11 +90,6 @@ RUN pip3 install yapf==0.28.0
ADD docs/.sphinx/requirements.txt /doc-requirements.txt
RUN pip3 install -r /doc-requirements.txt
# Download real models to run onnx unit tests
ENV ONNX_HOME=/.onnx
COPY ./tools/download_models.sh /
RUN /download_models.sh && rm /download_models.sh
# Install latest ccache version
RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake
RUN cget -p $PREFIX install ccache@v4.1 -DENABLE_TESTING=OFF
......
......@@ -63,7 +63,8 @@ def clang_format(against, apply=False, path=CLANG_FORMAT_PATH):
print(f"{git_clang_format} not installed. Skipping format.")
return
diff_flag = "" if apply else "--diff"
run(f"{git_clang_format} --binary {clang_format} {diff_flag} {base}")
run(f"{git_clang_format} --extensions c,cpp,hpp,h,cl,hip,in --binary {clang_format} {diff_flag} {base}"
)
def get_files_changed(against, ext=('py')):
......
......@@ -53,7 +53,8 @@ else
python3-pip \
python3-venv \
rocblas-dev \
rocm-cmake
rocm-cmake \
libtbb-dev
fi
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment