Commit fc5db1ad authored by Manupa Karunaratne's avatar Manupa Karunaratne
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mlir-attention

parents 930b147c b249fb8a
......@@ -29,4 +29,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@507bb94ce7873786486d296ec81d2eadaab49003 -DBUILD_FAT_LIBROCKCOMPILER=On
\ No newline at end of file
ROCmSoftwarePlatform/rocMLIR@3700afd2564e21267a4d1fd8f1f80465f45daa93 -DBUILD_FAT_LIBROCKCOMPILER=On
......@@ -37,6 +37,8 @@ namespace op {
* Static allocate:
* No inputs: `allocate()`
* `this.s` attribute set to the static output shape of the buffer.
* `this.s` attribute can be set to a dynamic output shape; however this will allocate the maximum
* buffer size for that case
*
* Dynamic allocate:
* One input: `allocate(output_dims)`
......@@ -74,10 +76,6 @@ struct allocate
}
else
{
if(s->dynamic())
{
MIGRAPHX_THROW("ALLOCATE: dynamic shape attribute and no input");
}
migraphx::check_shapes{inputs, *this, false}.has(0);
}
return s.value();
......
......@@ -147,8 +147,8 @@ void quantize_int8(program& prog,
run_passes(prog,
{quantize_int8_pass{ins_names, *int8_quant_params},
optimize_module{},
simplify_qdq{},
optimize_module{},
dead_code_elimination{}});
}
......
cc7e8cc21f83df3a41d9736dba9211bb832764ad
2eeafc37bca21dc8bf337dda7020b486543162d7
......@@ -116,11 +116,12 @@ TEST_CASE(allocate_dyn_with_shape_attr)
input);
}
TEST_CASE(allocate_dyn_no_input_error)
TEST_CASE(allocate_dyn_no_input)
{
migraphx::shape shape_attr{migraphx::shape::float_type,
{{1, 4}, {3, 3}, {4, 8, {4, 6}}, {4, 8}, {4, 6}}};
throws_shape(migraphx::make_op("allocate", {{"shape", migraphx::to_value(shape_attr)}}));
expect_shape(shape_attr,
migraphx::make_op("allocate", {{"shape", migraphx::to_value(shape_attr)}}));
}
TEST_CASE(allocate_shape_and_buf_type_error)
......
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -52,6 +52,12 @@ def parse_args():
parser.add_argument('--fill0',
action='store_true',
help='fill all arguments with a value of 0')
parser.add_argument('--fp16',
action='store_true',
help='quantize MIGraphX model to fp16')
parser.add_argument('--argmax',
action='store_true',
help='use argmax for accuracy')
parser.add_argument('--verbose',
action='store_true',
help='show verbose information (for debugging)')
......@@ -105,7 +111,7 @@ def parse_args():
args = parser.parse_args()
return args
return args, parser
# taken from ../test_runner.py
......@@ -113,6 +119,7 @@ def check_correctness(gold_outputs,
outputs,
rtol=1e-3,
atol=1e-3,
use_argmax=False,
verbose=False):
if len(gold_outputs) != len(outputs):
print('Number of outputs {} is not equal to expected number {}'.format(
......@@ -121,6 +128,8 @@ def check_correctness(gold_outputs,
out_num = len(gold_outputs)
ret = True
if not use_argmax:
for i in range(out_num):
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
ret = False
......@@ -132,7 +141,16 @@ def check_correctness(gold_outputs,
else:
print('Outputs do not match')
break
else:
golden_argmax = np.argmax(gold_outputs)
actual_argmax = np.argmax(outputs)
if actual_argmax != golden_argmax:
ret = False
print('\nOutput argmax is incorrect ...')
if verbose:
print('Expected argmax value: \n{}'.format(golden_argmax))
print('......')
print('Actual argmax value: \n{}\n'.format(actual_argmax))
return ret
......@@ -155,13 +173,14 @@ def get_np_datatype(in_type):
def main():
args = parse_args()
args, parser = parse_args()
use_onnx = True
if args.onnx == None:
use_onnx = False
if not use_onnx and args.tf == None:
print('Error: please specify either an onnx or tf pb file')
parser.print_help()
sys.exit(-1)
model_name = args.onnx
......@@ -194,6 +213,9 @@ def main():
batch_size=batch,
map_input_dims=input_dims)
if (args.fp16):
migraphx.quantize_fp16(model)
if args.verbose:
print(model)
......@@ -300,7 +322,8 @@ def main():
if not args.ort_run:
is_correct = check_correctness(pred_fw, pred_migx, args.tolerance,
args.tolerance, args.verbose)
args.tolerance, args.argmax,
args.verbose)
verbose_string = ' Rerun with --verbose for detailed information.' \
if not args.verbose else ''
if is_correct:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment