Commit b2f12dae authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into ck-gemm-int8

parents 05122cc7 ce2adafd
...@@ -49,7 +49,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m, ...@@ -49,7 +49,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast); auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow); auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
auto epsilon_mbcast = m.add_instruction( auto epsilon_mbcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon); migraphx::make_op("multibroadcast", {{"out_lens", {dims.at(0), dims.at(1), 1}}}), epsilon);
auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast); auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
auto sqrt = m.add_instruction(migraphx::make_op("sqrt"), add_epsilon); auto sqrt = m.add_instruction(migraphx::make_op("sqrt"), add_epsilon);
auto sqrt_mbcast = auto sqrt_mbcast =
...@@ -57,7 +58,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m, ...@@ -57,7 +58,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
auto div = m.add_instruction(migraphx::make_op("div"), sub, sqrt_mbcast); auto div = m.add_instruction(migraphx::make_op("div"), sub, sqrt_mbcast);
auto scale_mbcast = auto scale_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), scale); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), scale);
auto mul = m.add_instruction(migraphx::make_op("mul"), scale_mbcast, div); auto mul = m.add_instruction(migraphx::make_op("mul"), div, scale_mbcast);
auto bias_mbcast = auto bias_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias);
return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast); return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast);
...@@ -161,3 +163,21 @@ struct test_layernorm_triadd_large : verify_program<test_layernorm_triadd_large> ...@@ -161,3 +163,21 @@ struct test_layernorm_triadd_large : verify_program<test_layernorm_triadd_large>
return p; return p;
} }
}; };
struct test_add_layernorm_add_gemm_nonstd : verify_program<test_add_layernorm_add_gemm_nonstd>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s =
migraphx::shape::from_permutation(migraphx::shape::float_type, {8, 1, 16}, {1, 2, 0});
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {8, 16, 64}});
auto add = mm->add_instruction(migraphx::make_op("add"), x, y);
auto layernorm_ins = add_layernorm(*mm, add, s.lens());
mm->add_instruction(migraphx::make_op("dot"), layernorm_ins, z);
return p;
}
};
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
##################################################################################### #####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -38,7 +38,7 @@ def getipynb_markdownBlockAsList(): ...@@ -38,7 +38,7 @@ def getipynb_markdownBlockAsList():
'\t\t"cell_type": "code",\n', '\t\t"execution_count": null,\n', '\t\t"cell_type": "code",\n', '\t\t"execution_count": null,\n',
'\t\t"metadata": {},\n', '\t\t"outputs": [],\n', '\t\t"source": [\n', '\t\t"metadata": {},\n', '\t\t"outputs": [],\n', '\t\t"source": [\n',
'\t\t\t\"# The MIT License (MIT)\\n\",\n', '\t\t\t\"#\\n\",\n', '\t\t\t\"# The MIT License (MIT)\\n\",\n', '\t\t\t\"#\\n\",\n',
'\t\t\t\"# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.\\n\",\n', '\t\t\t\"# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.\\n\",\n',
'\t\t\t\"#\\n\",\n', '\t\t\t\"#\\n\",\n',
'\t\t\t\"# Permission is hereby granted, free of charge, to any person obtaining a copy\\n\",\n', '\t\t\t\"# Permission is hereby granted, free of charge, to any person obtaining a copy\\n\",\n',
'\t\t\t\"# of this software and associated documentation files (the \'Software\'), to deal\\n\",\n', '\t\t\t\"# of this software and associated documentation files (the \'Software\'), to deal\\n\",\n',
......
...@@ -39,6 +39,15 @@ def parse_args(): ...@@ -39,6 +39,15 @@ def parse_args():
type=str, type=str,
default='gpu', default='gpu',
help='Specify where the tests execute (ref, gpu)') help='Specify where the tests execute (ref, gpu)')
parser.add_argument('--fp16', action='store_true', help='Quantize to fp16')
parser.add_argument('--atol',
type=float,
default=1e-3,
help='The absolute tolerance parameter')
parser.add_argument('--rtol',
type=float,
default=1e-3,
help='The relative tolerance parameter')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -257,6 +266,8 @@ def main(): ...@@ -257,6 +266,8 @@ def main():
# read and compile model # read and compile model
model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes) model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes)
if args.fp16:
migraphx.quantize_fp16(model)
model.compile(migraphx.get_target(target)) model.compile(migraphx.get_target(target))
# get test cases # get test cases
...@@ -279,7 +290,10 @@ def main(): ...@@ -279,7 +290,10 @@ def main():
output_data = run_one_case(model, input_data) output_data = run_one_case(model, input_data)
# check output correctness # check output correctness
ret = check_correctness(gold_outputs, output_data) ret = check_correctness(gold_outputs,
output_data,
atol=args.atol,
rtol=args.rtol)
if ret: if ret:
correct_num += 1 correct_num += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment