Commit fc5db1ad authored by Manupa Karunaratne's avatar Manupa Karunaratne
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mlir-attention

parents 930b147c b249fb8a
...@@ -29,4 +29,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build ...@@ -29,4 +29,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@507bb94ce7873786486d296ec81d2eadaab49003 -DBUILD_FAT_LIBROCKCOMPILER=On ROCmSoftwarePlatform/rocMLIR@3700afd2564e21267a4d1fd8f1f80465f45daa93 -DBUILD_FAT_LIBROCKCOMPILER=On
\ No newline at end of file
...@@ -37,6 +37,8 @@ namespace op { ...@@ -37,6 +37,8 @@ namespace op {
* Static allocate: * Static allocate:
* No inputs: `allocate()` * No inputs: `allocate()`
* `this.s` attribute set to the static output shape of the buffer. * `this.s` attribute set to the static output shape of the buffer.
* `this.s` attribute can be set to a dynamic output shape; however this will allocate the maximum
* buffer size for that case
* *
* Dynamic allocate: * Dynamic allocate:
* One input: `allocate(output_dims)` * One input: `allocate(output_dims)`
...@@ -74,10 +76,6 @@ struct allocate ...@@ -74,10 +76,6 @@ struct allocate
} }
else else
{ {
if(s->dynamic())
{
MIGRAPHX_THROW("ALLOCATE: dynamic shape attribute and no input");
}
migraphx::check_shapes{inputs, *this, false}.has(0); migraphx::check_shapes{inputs, *this, false}.has(0);
} }
return s.value(); return s.value();
......
...@@ -147,8 +147,8 @@ void quantize_int8(program& prog, ...@@ -147,8 +147,8 @@ void quantize_int8(program& prog,
run_passes(prog, run_passes(prog,
{quantize_int8_pass{ins_names, *int8_quant_params}, {quantize_int8_pass{ins_names, *int8_quant_params},
optimize_module{},
simplify_qdq{}, simplify_qdq{},
optimize_module{},
dead_code_elimination{}}); dead_code_elimination{}});
} }
......
cc7e8cc21f83df3a41d9736dba9211bb832764ad 2eeafc37bca21dc8bf337dda7020b486543162d7
...@@ -116,11 +116,12 @@ TEST_CASE(allocate_dyn_with_shape_attr) ...@@ -116,11 +116,12 @@ TEST_CASE(allocate_dyn_with_shape_attr)
input); input);
} }
TEST_CASE(allocate_dyn_no_input_error) TEST_CASE(allocate_dyn_no_input)
{ {
migraphx::shape shape_attr{migraphx::shape::float_type, migraphx::shape shape_attr{migraphx::shape::float_type,
{{1, 4}, {3, 3}, {4, 8, {4, 6}}, {4, 8}, {4, 6}}}; {{1, 4}, {3, 3}, {4, 8, {4, 6}}, {4, 8}, {4, 6}}};
throws_shape(migraphx::make_op("allocate", {{"shape", migraphx::to_value(shape_attr)}})); expect_shape(shape_attr,
migraphx::make_op("allocate", {{"shape", migraphx::to_value(shape_attr)}}));
} }
TEST_CASE(allocate_shape_and_buf_type_error) TEST_CASE(allocate_shape_and_buf_type_error)
......
##################################################################################### #####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -52,6 +52,12 @@ def parse_args(): ...@@ -52,6 +52,12 @@ def parse_args():
parser.add_argument('--fill0', parser.add_argument('--fill0',
action='store_true', action='store_true',
help='fill all arguments with a value of 0') help='fill all arguments with a value of 0')
parser.add_argument('--fp16',
action='store_true',
help='quantize MIGraphX model to fp16')
parser.add_argument('--argmax',
action='store_true',
help='use argmax for accuracy')
parser.add_argument('--verbose', parser.add_argument('--verbose',
action='store_true', action='store_true',
help='show verbose information (for debugging)') help='show verbose information (for debugging)')
...@@ -105,7 +111,7 @@ def parse_args(): ...@@ -105,7 +111,7 @@ def parse_args():
args = parser.parse_args() args = parser.parse_args()
return args return args, parser
# taken from ../test_runner.py # taken from ../test_runner.py
...@@ -113,6 +119,7 @@ def check_correctness(gold_outputs, ...@@ -113,6 +119,7 @@ def check_correctness(gold_outputs,
outputs, outputs,
rtol=1e-3, rtol=1e-3,
atol=1e-3, atol=1e-3,
use_argmax=False,
verbose=False): verbose=False):
if len(gold_outputs) != len(outputs): if len(gold_outputs) != len(outputs):
print('Number of outputs {} is not equal to expected number {}'.format( print('Number of outputs {} is not equal to expected number {}'.format(
...@@ -121,6 +128,8 @@ def check_correctness(gold_outputs, ...@@ -121,6 +128,8 @@ def check_correctness(gold_outputs,
out_num = len(gold_outputs) out_num = len(gold_outputs)
ret = True ret = True
if not use_argmax:
for i in range(out_num): for i in range(out_num):
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol): if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
ret = False ret = False
...@@ -132,7 +141,16 @@ def check_correctness(gold_outputs, ...@@ -132,7 +141,16 @@ def check_correctness(gold_outputs,
else: else:
print('Outputs do not match') print('Outputs do not match')
break break
else:
golden_argmax = np.argmax(gold_outputs)
actual_argmax = np.argmax(outputs)
if actual_argmax != golden_argmax:
ret = False
print('\nOutput argmax is incorrect ...')
if verbose:
print('Expected argmax value: \n{}'.format(golden_argmax))
print('......')
print('Actual argmax value: \n{}\n'.format(actual_argmax))
return ret return ret
...@@ -155,13 +173,14 @@ def get_np_datatype(in_type): ...@@ -155,13 +173,14 @@ def get_np_datatype(in_type):
def main(): def main():
args = parse_args() args, parser = parse_args()
use_onnx = True use_onnx = True
if args.onnx == None: if args.onnx == None:
use_onnx = False use_onnx = False
if not use_onnx and args.tf == None: if not use_onnx and args.tf == None:
print('Error: please specify either an onnx or tf pb file') print('Error: please specify either an onnx or tf pb file')
parser.print_help()
sys.exit(-1) sys.exit(-1)
model_name = args.onnx model_name = args.onnx
...@@ -194,6 +213,9 @@ def main(): ...@@ -194,6 +213,9 @@ def main():
batch_size=batch, batch_size=batch,
map_input_dims=input_dims) map_input_dims=input_dims)
if (args.fp16):
migraphx.quantize_fp16(model)
if args.verbose: if args.verbose:
print(model) print(model)
...@@ -300,7 +322,8 @@ def main(): ...@@ -300,7 +322,8 @@ def main():
if not args.ort_run: if not args.ort_run:
is_correct = check_correctness(pred_fw, pred_migx, args.tolerance, is_correct = check_correctness(pred_fw, pred_migx, args.tolerance,
args.tolerance, args.verbose) args.tolerance, args.argmax,
args.verbose)
verbose_string = ' Rerun with --verbose for detailed information.' \ verbose_string = ' Rerun with --verbose for detailed information.' \
if not args.verbose else '' if not args.verbose else ''
if is_correct: if is_correct:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment