Commit 0a6bc406 authored by umangyadav's avatar umangyadav
Browse files

Add logic to convert fp16 in parse instance norm, add flags for numpy in accuracy checker

parent 88fb551c
......@@ -39,22 +39,39 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
std::vector<instruction_ref> oargs) const
{
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
float epsilon = 1e-5f;
bool convert_fp16 = true;
float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
auto dtype = oargs[0]->get_shape().type();
auto literal_dtype = dtype;
std::vector<instruction_ref> args;
if(dtype == shape::half_type and convert_fp16)
{
args.push_back(info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), oargs[0]));
args.push_back(info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), oargs[1]));
args.push_back(info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), oargs[2]));
literal_dtype = shape::float_type;
}
else
{
args = oargs;
}
auto x = args[0];
auto scale = args[1];
auto bias = args[2];
auto dims = x->get_shape().lens();
auto dtype = x->get_shape().type();
if(not contains(valid_types, dtype))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double).");
......@@ -72,7 +89,7 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(literal{shape{dtype}, {epsilon}});
auto epsilon_literal = info.add_literal(literal{shape{literal_dtype}, {epsilon}});
auto epsilon_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast =
......@@ -85,8 +102,14 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
;
auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
auto ret = info.add_instruction(make_op("add"), l5, bias_bcast);
if(dtype == shape::half_type and convert_fp16)
{
return info.add_instruction(make_op("convert", {{"target_type", shape::half_type}}),
ret);
}
return ret;
}
};
......
......@@ -52,6 +52,8 @@ def parse_args():
parser.add_argument('--fill0',
action='store_true',
help='fill all arguments with a value of 0')
parser.add_argument('--numpy', action='append', help='fill argument with numpy saved array', type=str)
parser.add_argument('--np_path', action='append', help='Path for the saved numpy array', type=str)
parser.add_argument('--verbose',
action='store_true',
help='show verbose information (for debugging)')
......@@ -81,7 +83,7 @@ def parse_args():
action='store_true',
default=False,
help='Turn on ort VERBOSE logging via session options')
parser.add_argument('--save-ort-res', dest="save_ort", type=str, help='Save output of ORT as numpy array at path provided by this argument')
args = parser.parse_args()
return args
......@@ -186,7 +188,12 @@ def main():
print(f'Parameter {name} -> {shape}')
in_shape = shape.lens()
in_type = shape.type_string()
if not args.fill1 and not args.fill0:
if name in args.numpy:
in_path = args.np_path[args.numpy.index(name)]
if args.verbose:
print("Loading numpy array for input name : {name}, from path : {in_path}")
test_input = np.load(in_path).astype(get_np_datatype(in_type))
elif not args.fill1 and not args.fill0:
test_input = np.random.rand(*(in_shape)).astype(
get_np_datatype(in_type))
elif not args.fill0:
......@@ -216,6 +223,10 @@ def main():
try:
pred_fw = sess.run(None, ort_params)[-1]
if(args.save_ort):
if args.verbose:
print("saving ORT result as numpy array at location : {args.save_ort}")
np.save(args.save_ort, pred_fw)
except Exception as e:
if any(input_dims):
print(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment