"vscode:/vscode.git/clone" did not exist on "80120f0a0c524d1efc0249926a73d5020f0efd67"
Commit 0a6bc406 authored by umangyadav's avatar umangyadav
Browse files

Add logic to convert fp16 in parse instance norm, add flags for numpy in accuracy checker

parent 88fb551c
...@@ -39,22 +39,39 @@ struct parse_instancenorm : op_parser<parse_instancenorm> ...@@ -39,22 +39,39 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
const onnx_parser& parser, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> oargs) const
{ {
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({D1, D2, ... Dk}, x) // mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2) // variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
bool convert_fp16 = true;
float epsilon = 1e-5f; float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon")) if(contains(info.attributes, "epsilon"))
{ {
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>(); epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
} }
auto dtype = oargs[0]->get_shape().type();
auto literal_dtype = dtype;
std::vector<instruction_ref> args;
if(dtype == shape::half_type and convert_fp16)
{
args.push_back(info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), oargs[0]));
args.push_back(info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), oargs[1]));
args.push_back(info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), oargs[2]));
literal_dtype = shape::float_type;
}
else
{
args = oargs;
}
auto x = args[0]; auto x = args[0];
auto scale = args[1]; auto scale = args[1];
auto bias = args[2]; auto bias = args[2];
auto dims = x->get_shape().lens(); auto dims = x->get_shape().lens();
auto dtype = x->get_shape().type();
if(not contains(valid_types, dtype)) if(not contains(valid_types, dtype))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) + MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double)."); ". Valid types are 1 (float), 10 (half), and 11 (double).");
...@@ -72,7 +89,7 @@ struct parse_instancenorm : op_parser<parse_instancenorm> ...@@ -72,7 +89,7 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast); auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0); auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast); auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(literal{shape{dtype}, {epsilon}}); auto epsilon_literal = info.add_literal(literal{shape{literal_dtype}, {epsilon}});
auto epsilon_bcast = auto epsilon_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal); info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast = auto variance_bcast =
...@@ -85,8 +102,14 @@ struct parse_instancenorm : op_parser<parse_instancenorm> ...@@ -85,8 +102,14 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
; ;
auto bias_bcast = auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias); info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast); auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast); auto ret = info.add_instruction(make_op("add"), l5, bias_bcast);
if(dtype == shape::half_type and convert_fp16)
{
return info.add_instruction(make_op("convert", {{"target_type", shape::half_type}}),
ret);
}
return ret;
} }
}; };
......
...@@ -52,6 +52,8 @@ def parse_args(): ...@@ -52,6 +52,8 @@ def parse_args():
parser.add_argument('--fill0', parser.add_argument('--fill0',
action='store_true', action='store_true',
help='fill all arguments with a value of 0') help='fill all arguments with a value of 0')
parser.add_argument('--numpy', action='append', help='fill argument with numpy saved array', type=str)
parser.add_argument('--np_path', action='append', help='Path for the saved numpy array', type=str)
parser.add_argument('--verbose', parser.add_argument('--verbose',
action='store_true', action='store_true',
help='show verbose information (for debugging)') help='show verbose information (for debugging)')
...@@ -81,7 +83,7 @@ def parse_args(): ...@@ -81,7 +83,7 @@ def parse_args():
action='store_true', action='store_true',
default=False, default=False,
help='Turn on ort VERBOSE logging via session options') help='Turn on ort VERBOSE logging via session options')
parser.add_argument('--save-ort-res', dest="save_ort", type=str, help='Save output of ORT as numpy array at path provided by this argument')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -186,7 +188,12 @@ def main(): ...@@ -186,7 +188,12 @@ def main():
print(f'Parameter {name} -> {shape}') print(f'Parameter {name} -> {shape}')
in_shape = shape.lens() in_shape = shape.lens()
in_type = shape.type_string() in_type = shape.type_string()
if not args.fill1 and not args.fill0: if name in args.numpy:
in_path = args.np_path[args.numpy.index(name)]
if args.verbose:
print("Loading numpy array for input name : {name}, from path : {in_path}")
test_input = np.load(in_path).astype(get_np_datatype(in_type))
elif not args.fill1 and not args.fill0:
test_input = np.random.rand(*(in_shape)).astype( test_input = np.random.rand(*(in_shape)).astype(
get_np_datatype(in_type)) get_np_datatype(in_type))
elif not args.fill0: elif not args.fill0:
...@@ -216,6 +223,10 @@ def main(): ...@@ -216,6 +223,10 @@ def main():
try: try:
pred_fw = sess.run(None, ort_params)[-1] pred_fw = sess.run(None, ort_params)[-1]
if(args.save_ort):
if args.verbose:
print("saving ORT result as numpy array at location : {args.save_ort}")
np.save(args.save_ort, pred_fw)
except Exception as e: except Exception as e:
if any(input_dims): if any(input_dims):
print( print(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment