Unverified Commit 7f97b8ef authored by Ted Themistokleous, committed by GitHub

Merge branch 'simplify_1_mul_div_ops' into divide_by_zero_check

parents 2ba401f0 d1fed367
......@@ -184,6 +184,12 @@ struct value
{
}
explicit binary(std::size_t s) : base(s) {}
friend std::ostream& operator<<(std::ostream& os, const binary& obj)
{
os << "{binary_object: " << obj.size() << "}";
return os;
}
};
value() = default;
......
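The hunk above gives value::binary a stream operator that prints the blob's size instead of dumping raw bytes. A minimal standalone sketch of the same pattern, with binary_blob as a hypothetical stand-in for the real type:

#include <cstddef>
#include <iostream>
#include <vector>

struct binary_blob
{
    std::vector<unsigned char> data;
    std::size_t size() const { return data.size(); }
    // Print a compact summary rather than the raw bytes, mirroring the
    // operator<< added above.
    friend std::ostream& operator<<(std::ostream& os, const binary_blob& obj)
    {
        os << "{binary_object: " << obj.size() << "}";
        return os;
    }
};

int main()
{
    binary_blob b{{0x01, 0x02, 0x03}};
    std::cout << b << '\n'; // prints: {binary_object: 3}
}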
......@@ -40,6 +40,12 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
auto val = op.to_value();
auto op_padding = val.at("padding").to_vector<size_t>();
// skip if shape is dynamic
if(input->get_shape().dynamic())
{
return;
}
auto kdims = input->get_shape().lens().size() - 2;
if(std::equal(op_padding.begin(),
op_padding.begin() + kdims,
......
......@@ -176,13 +176,13 @@ bool operator==(const instruction& x, const instruction& y)
return true;
}
bool operator!=(const instruction& x, const instruction& y) { return !(x == y); }
bool operator!=(const instruction& x, const instruction& y) { return not(x == y); }
bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
bool operator!=(const instruction& i, instruction_ref ref) { return not(i == ref); }
bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
bool operator!=(instruction_ref ref, const instruction& i) { return not(i == ref); }
void instruction::add_output(instruction_ref ins)
{
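These rewrites swap `!` for the `not` keyword. In standard C++, `not`, `and`, and `or` are alternative tokens for `!`, `&&`, and `||`; unlike C, no header is required, so the two spellings below are interchangeable:

// Identical semantics; 'not' is a core-language alternative token in C++.
bool differs_bang(int a, int b) { return !(a == b); }
bool differs_not(int a, int b) { return not(a == b); }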
......@@ -361,7 +361,7 @@ void instruction::print(std::ostream& os,
os << "{" << ins->get_literal() << "}";
}
if(!ins->inputs().empty())
if(not ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->inputs())
......@@ -374,7 +374,7 @@ void instruction::print(std::ostream& os,
}
// print module inputs
if(!ins->module_inputs().empty())
if(not ins->module_inputs().empty())
{
std::string delim = ", [";
for(auto&& mod_arg : ins->module_inputs())
......@@ -445,8 +445,8 @@ operation instruction::normalized_operator() const
operation o = this->get_operator();
if(this->need_normalization())
{
auto lens = this->inputs().front()->get_shape().lens();
if(!normalize_attributes(o, lens))
auto s = this->inputs().front()->get_shape();
if(not normalize_attributes(o, s.max_lens()))
return this->get_operator();
}
return o;
......
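normalized_operator() now passes max_lens() instead of lens(), since a dynamic shape has no single set of lengths. A hedged sketch of the idea, with dyn_dim as a simplified stand-in for shape::dynamic_dimension:

#include <cstddef>
#include <vector>

struct dyn_dim
{
    std::size_t min;
    std::size_t max;
};

// Collapse each [min, max] range to its upper bound so attribute
// normalization (e.g. resolving negative axes) has concrete lengths.
std::vector<std::size_t> max_lens(const std::vector<dyn_dim>& dims)
{
    std::vector<std::size_t> out;
    out.reserve(dims.size());
    for(const auto& d : dims)
        out.push_back(d.max);
    return out;
}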
......@@ -64,5 +64,10 @@ operation make_op_from_value(const std::string& name, const value& v)
});
}
operation make_json_op(const std::string& name, const std::string& s)
{
return make_op(name, from_json_string(convert_to_json(s)));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
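make_json_op lets callers pass attributes as a relaxed JSON string; convert_to_json quotes bare keys before from_json_string parses the result into a value. A hedged usage sketch (header locations assumed):

#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>

// Both should construct the same op; the second accepts unquoted keys.
auto op_a = migraphx::make_op("convolution",
                              migraphx::from_json_string(R"({"padding": [1, 1]})"));
auto op_b = migraphx::make_json_op("convolution", "{padding: [1, 1]}");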
......@@ -141,12 +141,12 @@ void module::set_bypass(bool b) { impl->bypass = b; }
void module::assign(const module& m)
{
// copy the impl
if(!impl)
if(not impl)
impl = std::make_unique<module_impl>();
*impl = *m.impl;
// clear instructions
if(!impl->instructions.empty())
if(not impl->instructions.empty())
{
impl->clear();
}
......@@ -357,7 +357,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
assert(out->valid(begin()));
}
// Replacement should not be dead code unless it's the last instruction
assert(!rep->outputs().empty() or rep == std::prev(end()));
assert(not rep->outputs().empty() or rep == std::prev(end()));
// Output of the original instruction should only be the replacement or empty
assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(),
ins->outputs().end(),
......@@ -396,9 +396,13 @@ instruction_ref module::move_instruction(instruction_ref src, instruction_ref ds
instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst)
{
this->move_instruction(src, dst);
for(auto ins : src->inputs())
this->move_instruction(ins, src);
{
if(not contains(this->impl->instructions, ins))
continue;
this->move_instructions(ins, dst);
}
this->move_instruction(src, dst);
return src;
}
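move_instructions is now recursive: each input still owned by this module is hoisted ahead of dst before src itself is moved, so producers always end up before consumers. A minimal stand-in using std::list (assumes the instruction graph is acyclic):

#include <algorithm>
#include <list>
#include <vector>

struct node
{
    std::vector<node*> inputs;
};

// Depth-first hoist: move every producer of src that is still in insts
// ahead of dst, then move src itself, so src lands after its inputs.
void move_node(std::list<node*>& insts, node* src, std::list<node*>::iterator dst)
{
    for(node* in : src->inputs)
    {
        if(std::find(insts.begin(), insts.end(), in) == insts.end())
            continue; // input lives elsewhere (e.g. another module); skip it
        move_node(insts, in, dst);
    }
    insts.splice(dst, insts, std::find(insts.begin(), insts.end(), src));
}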
......@@ -623,7 +627,7 @@ instruction_ref module::validate() const
auto inputs = i.inputs();
bool check_order = std::all_of(
inputs.begin(), inputs.end(), [&](auto in) { return has_instruction(in); });
return !i.valid(impl->instructions.begin(), check_order);
return not i.valid(impl->instructions.begin(), check_order);
});
}
......@@ -788,7 +792,7 @@ void module::print_graph(std::ostream& os, bool brief) const
label = to_string(ins->get_operator());
os << "\t" << enclose_name(ins_names.at(ins)) << "[label=" << enclose_name(label) << "]";
os << ";" << std::endl;
if(!ins->inputs().empty())
if(not ins->inputs().empty())
{
for(auto&& arg : ins->inputs())
{
......@@ -822,12 +826,15 @@ static std::string cpp_var_name(const std::string& name)
static void print_make_op(std::ostream& os, const operation& op)
{
os << "migraphx::make_op(" << enclose_name(op.name());
auto v = op.to_value();
if(not v.empty())
{
os << ", "
<< "migraphx::from_json_string(" << enclose_name(to_json_string(v)) << ")";
os << "migraphx::make_json_op(" << enclose_name(op.name());
os << ", " << enclose_name(to_json_string(v));
}
else
{
os << "migraphx::make_op(" << enclose_name(op.name());
}
os << ")";
}
......@@ -939,7 +946,7 @@ module& module::sort()
this->move_instruction(ins, this->begin());
for(auto child : ins->inputs())
{
if(!contains(this->impl->instructions, child))
if(not contains(this->impl->instructions, child))
{
continue;
}
......
......@@ -79,14 +79,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
if(contains(vec_attrs, op::normalize_attribute::include_max))
{
if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
}
}
else
{
if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
}
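Despite its name, the std::equal overload taking a predicate applies it pairwise, so with std::less_equal it performs an element-wise `result[i] <= max_vals[i]` range check:

#include <algorithm>
#include <cstdint>
#include <functional>
#include <vector>

int main()
{
    std::vector<std::int64_t> result{1, 2, 3};
    std::vector<std::int64_t> max_vals{3, 3, 3};
    // True only when every element of result is <= its counterpart.
    bool in_range = std::equal(
        result.begin(), result.end(), max_vals.begin(), std::less_equal<>{});
    return in_range ? 0 : 1;
}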
......@@ -118,14 +118,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
if(contains(vec_attrs, op::normalize_attribute::include_min))
{
if(!std::equal(min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
if(not std::equal(
min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
}
}
else
{
if(!std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
}
......@@ -174,7 +175,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
tuned = true;
}
}
if(!attrs.contains("normalize_axes"))
if(not attrs.contains("normalize_axes"))
{
return tuned;
}
......
......@@ -43,9 +43,9 @@ void normalize_ops::apply(module& m) const
if(inputs.empty())
continue;
auto lens = inputs[0]->get_shape().lens();
auto s = inputs[0]->get_shape();
migraphx::operation tuned_op = ins->get_operator();
if(normalize_attributes(tuned_op, lens))
if(normalize_attributes(tuned_op, s.max_lens()))
{
m.replace_instruction(ins, tuned_op, inputs);
ins->set_normalized();
......
......@@ -97,6 +97,7 @@ struct onnx_parser
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false;
bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10;
int64_t opset_version = 13;
......@@ -119,6 +120,7 @@ struct onnx_parser
};
shape::type_t get_type(int dtype);
bool is_type_float(shape::type_t dtype);
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -60,8 +60,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
parser.default_dyn_dim_value = options.default_dyn_dim_value;
}
if(not options.map_input_dims.empty() and not options.map_dyn_input_dims.empty())
{
MIGRAPHX_THROW("PARSE_ONNX_FROM: both map_input_dims and map_dyn_input_dims non-empty, only"
"one should be used");
}
parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations;
parser.use_dyn_output = options.use_dyn_output;
if(options.print_program_on_error)
{
......@@ -80,6 +86,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
parser.parse_from(std::forward<Ts>(xs)...);
}
return std::move(parser.prog);
}
......
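With the check hoisted into parse_onnx_from, an options struct may populate map_input_dims or map_dyn_input_dims, never both. A hedged usage sketch (field names taken from the diff; parse_onnx assumed to be the public entry point):

#include <migraphx/onnx.hpp>

migraphx::program load_model()
{
    migraphx::onnx_options options;
    options.map_input_dims["data"] = {1, 3, 224, 224}; // fixed dims...
    // options.map_dyn_input_dims["data"] = ...;       // ...or dynamic, not both
    options.use_dyn_output = false;
    return migraphx::parse_onnx("model.onnx", options);
}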
......@@ -28,7 +28,6 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/common.hpp>
#include <migraphx/type_traits.hpp>
#include <migraphx/float_equal.hpp>
......@@ -60,7 +59,7 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const
std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
if(elem_num == 0)
{
return {};
return literal{shape_type};
}
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
......@@ -77,7 +76,7 @@ static literal create_literal(shape::type_t shape_type, const std::vector<size_t
std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
if(elem_num == 0)
{
return {};
return literal{shape_type};
}
// scalar input
......@@ -188,7 +187,7 @@ operation onnx_parser::load(const std::string& name, const node_info& info) cons
void onnx_parser::parse_undefined(module* mod, const std::string& name)
{
if(!contains(instructions, name))
if(not contains(instructions, name))
{
auto ins = mod->add_instruction(make_op("undefined"));
instructions[name] = ins;
......@@ -257,11 +256,6 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{
if(not map_input_dims.empty() and not map_dyn_input_dims.empty())
{
MIGRAPHX_THROW("PARSE_GRAPH: both map_input_dims and map_dyn_input_dims non-empty, only"
"one should be used");
}
std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer())
{
......@@ -273,7 +267,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{
const std::string& name = input.name();
// input not in initializer_data, so it is a real input
if(!contains(mod_insts, name))
if(not contains(mod_insts, name))
{
// ONNX specification does not specify how to deal with the
// scenario that a nested subgraph contains a parameter with the
......@@ -360,7 +354,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
all_output_names.begin(),
all_output_names.end(),
std::back_inserter(prog_output_names),
[&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });
[&](const auto& name) { return not(name.empty() or instructions.count(name) == 0); });
std::vector<instruction_ref> output_ins;
std::transform(prog_output_names.begin(),
......@@ -450,7 +444,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
const std::vector<std::size_t>& input_dims) const
{
shape::type_t shape_type = get_type(t.tensor_type().elem_type());
if(!input_dims.empty())
if(not input_dims.empty())
{
return {shape_type, input_dims};
}
......@@ -514,6 +508,16 @@ shape::type_t get_type(int dtype)
}
}
bool is_type_float(shape::type_t dtype)
{
bool r = false;
if(dtype == shape::float_type or dtype == shape::double_type or dtype == shape::half_type)
{
r = true;
}
return r;
}
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
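The new is_type_float helper reduces to a single boolean expression; an equivalent formulation:

// Behaviorally identical to the helper above.
bool is_type_float(shape::type_t dtype)
{
    return dtype == shape::float_type or dtype == shape::double_type or
           dtype == shape::half_type;
}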
......@@ -42,7 +42,7 @@ void cal_auto_padding_size(onnx_parser::node_info info,
size_t kdims = in_lens.size() - 2;
assert(k_lens.size() == kdims and dilation.size() == kdims);
if(!contains(info.attributes, "auto_pad"))
if(not contains(info.attributes, "auto_pad"))
{
return;
}
......@@ -124,7 +124,7 @@ void tune_padding_size(const value& v,
}
// if padding is symmetric, return directly
if(!is_asym_padding(padding))
if(not is_asym_padding(padding))
{
return;
}
......
......@@ -24,7 +24,7 @@
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -36,28 +36,63 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float epsilon = 1e-5f;
float momentum = 0.9f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
if(contains(info.attributes, "momentum"))
auto x_lens = args[0]->get_shape().lens();
auto x_type = args[0]->get_shape().type();
if(std::any_of(args.cbegin() + 1, args.cend(), [](auto a) {
return a->get_shape().lens().size() != 1;
}))
{
MIGRAPHX_THROW("PARSE_BATCHNORM: argument scale, bias, mean, or var rank != 1");
}
if(x_lens.size() == 1)
{
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto n0 = info.add_broadcastable_binary_op("sub", args[0], args[3]);
auto d0 = info.add_broadcastable_binary_op("add", args[4], eps);
auto d1 = info.add_broadcastable_binary_op("pow", d0, rt);
auto div0 = info.add_broadcastable_binary_op("div", n0, d1);
auto r0 = info.add_broadcastable_binary_op("mul", div0, args[1]);
return info.add_broadcastable_binary_op("add", r0, args[2]);
}
else if(x_lens.size() > 2)
{
momentum = parser.parse_value(info.attributes.at("momentum")).at<float>();
// unsqueeze tensors of shape (C) to broadcast correctly
std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1);
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto scale_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]);
auto bias_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[2]);
auto mean_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]);
auto var_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]);
auto numer = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt);
auto div0 = info.add_broadcastable_binary_op("div", numer, denom);
auto r0 = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
}
if(contains(info.attributes, "spatial"))
else
{
bn_mode = (parser.parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
? op::batch_norm_inference::spatial
: op::batch_norm_inference::per_activation;
// num dims either 0 or 2
MIGRAPHX_THROW("PARSE_BATCHNORM: rank " + std::to_string(x_lens.size()) +
" input tensor, unhandled data format");
}
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return info.add_instruction(op, args);
}
};
......
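The parser now expands BatchNormalization into elementwise ops computing y = scale * (x - mean) / sqrt(var + epsilon) + bias per channel, with pow(var + eps, 0.5) standing in for sqrt. A scalar reference of that formula (batch_norm_ref is a hypothetical helper; assumes NCHW layout with N == 1):

#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> batch_norm_ref(const std::vector<float>& x,
                                  const std::vector<float>& scale,
                                  const std::vector<float>& bias,
                                  const std::vector<float>& mean,
                                  const std::vector<float>& var,
                                  std::size_t channels,
                                  float epsilon = 1e-5f)
{
    std::vector<float> y(x.size());
    const std::size_t spatial = x.size() / channels; // elements per channel
    for(std::size_t c = 0; c < channels; ++c)
    {
        for(std::size_t i = 0; i < spatial; ++i)
        {
            const std::size_t idx = c * spatial + i;
            y[idx] = scale[c] * (x[idx] - mean[c]) / std::sqrt(var[c] + epsilon) +
                     bias[c];
        }
    }
    return y;
}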
......@@ -38,7 +38,7 @@ struct parse_cast : op_parser<parse_cast>
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
if(!contains(info.attributes, "to"))
if(not contains(info.attributes, "to"))
{
MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
}
......
......@@ -43,7 +43,7 @@ struct parse_constant : op_parser<parse_constant>
// return empty literal
if(v.get_shape().elements() == 0)
{
return info.add_literal(literal{});
return info.add_literal(literal{v.get_shape().type()});
}
auto dim_size = info.attributes.at("value").t().dims_size();
......
......@@ -93,7 +93,7 @@ struct parse_constant_fill : op_parser<parse_constant_fill>
}
else if(input_as_shape == 0)
{
if(!contains(info.attributes, "shape"))
if(not contains(info.attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
}
......
......@@ -47,15 +47,17 @@ struct parse_convolution : op_parser<parse_convolution>
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto op = make_op(opd.op_name);
auto values = op.to_value();
auto l0 = args[0];
auto weights = args[1];
auto in_lens = l0->get_shape().lens();
auto op = make_op(opd.op_name);
auto values = op.to_value();
auto l0 = args[0];
auto weights = args[1];
auto l0_shape = l0->get_shape();
auto w_shape = weights->get_shape();
auto in_lens = l0_shape.max_lens();
assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2;
// ensure pads availabe only when auto_pad is "NOT_SET"
// ensure pads available only when auto_pad is "NOT_SET"
check_padding_mode(info, "CONV");
if(contains(info.attributes, "strides"))
......@@ -79,21 +81,65 @@ struct parse_convolution : op_parser<parse_convolution>
copy(info.attributes["pads"].ints(), std::back_inserter(padding));
check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
}
if(contains(info.attributes, "auto_pad"))
{
auto weight_lens = weights->get_shape().lens();
std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
cal_auto_padding_size(info,
values,
k_lens,
values["dilation"].to_vector<std::size_t>(),
in_lens,
padding);
auto auto_pad = info.attributes["auto_pad"].s();
bool is_same_padding = false;
auto auto_pad = info.attributes["auto_pad"].s();
if(auto_pad.find("SAME") != std::string::npos)
{
values["padding_mode"] = to_value(op::padding_mode_t::same);
is_same_padding = true;
}
// check if image shape is dynamic
bool image_shape_dynamic = false;
if(l0_shape.dynamic())
{
auto dyn_dims = l0_shape.dyn_dims();
std::for_each(dyn_dims.begin() + 2, dyn_dims.end(), [&](auto dyn_dim) {
if(not dyn_dim.is_fixed())
{
image_shape_dynamic = true;
}
});
}
// check if kernel shape is dynamic
bool kernel_shape_dynamic = false;
if(w_shape.dynamic())
{
auto dyn_dims = w_shape.dyn_dims();
std::for_each(dyn_dims.begin() + 2, dyn_dims.end(), [&](auto dyn_dim) {
if(not dyn_dim.is_fixed())
{
kernel_shape_dynamic = true;
}
});
}
if(is_same_padding)
{
if(image_shape_dynamic or kernel_shape_dynamic)
{
// must calculate "same" padding with input shape data
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
values["padding_mode"] = is_same_upper
? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
values["use_dynamic_same_auto_pad"] = true;
}
else
{
values["padding_mode"] = to_value(op::padding_mode_t::same);
// kernel shape will be fixed, so max_lens() == min_lens() for kernel lengths
auto weight_lens = weights->get_shape().max_lens();
std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
cal_auto_padding_size(info,
values,
k_lens,
values["dilation"].to_vector<std::size_t>(),
in_lens,
padding);
}
}
}
values["padding"] = std::vector<size_t>(padding.begin(), padding.end());
......
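For fixed image and kernel dims the SAME_* padding is still resolved at parse time via cal_auto_padding_size; dynamic dims defer it to runtime through padding_mode and use_dynamic_same_auto_pad. A hedged sketch of the ONNX SAME_* rule for one fixed spatial dimension (same_padding is a hypothetical helper):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>

// Returns {begin_pad, end_pad} for one spatial dimension under SAME_* auto-pad:
// the output length is ceil(in / stride), and any odd leftover element goes to
// the end for SAME_UPPER, to the start for SAME_LOWER.
std::pair<std::size_t, std::size_t> same_padding(
    std::size_t in, std::size_t kernel, std::size_t stride, std::size_t dilation, bool upper)
{
    std::size_t effective_k = (kernel - 1) * dilation + 1;
    std::size_t out         = (in + stride - 1) / stride; // ceil(in / stride)
    std::int64_t total =
        static_cast<std::int64_t>((out - 1) * stride + effective_k) -
        static_cast<std::int64_t>(in);
    total                 = std::max<std::int64_t>(total, 0);
    std::size_t begin_pad = static_cast<std::size_t>(total / 2);
    std::size_t end_pad   = static_cast<std::size_t>(total) - begin_pad;
    return upper ? std::pair<std::size_t, std::size_t>(begin_pad, end_pad)
                 : std::pair<std::size_t, std::size_t>(end_pad, begin_pad);
}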
......@@ -94,7 +94,7 @@ struct parse_gemm : op_parser<parse_gemm>
out_lens.back() = l2->get_shape().lens().back();
auto l3 = args[2];
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
if(not std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
args[2]);
......
......@@ -58,7 +58,6 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Log", "log"},
{"LRN", "lrn"},
{"Neg", "neg"},
{"NonMaxSuppression", "nonmaxsuppression"},
{"Reciprocal", "recip"},
{"Relu", "relu"},
{"Round", "round"},
......@@ -75,7 +74,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
bool needs_contiguous(const std::string& op_name) const
{
return contains({"flatten", "gather", "nonmaxsuppression", "scatter"}, op_name);
return contains({"flatten", "gather", "scatter"}, op_name);
}
instruction_ref parse(const op_desc& opd,
......
......@@ -47,7 +47,8 @@ struct parse_if : op_parser<parse_if>
if(args.front()->get_shape().elements() != 1)
{
MIGRAPHX_THROW("PARSE_IF: condition input can have only one element!");
MIGRAPHX_THROW("PARSE_IF: " + info.name +
" condition input can have only one element!");
}
std::string then_name = info.name + "_if";
......@@ -69,7 +70,8 @@ struct parse_if : op_parser<parse_if>
else_out_shapes.begin(),
else_out_shapes.end()))
{
MIGRAPHX_THROW("PARSE_IF: then and else sub_grahps must have same output shapes!");
MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_grahps must have same output shapes!");
}
auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
......
......@@ -32,9 +32,12 @@ namespace onnx {
struct parse_instancenorm : op_parser<parse_instancenorm>
{
const std::set<shape::type_t> valid_types = {
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; }
instruction_ref parse(const op_desc& /*opd*/,
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
......@@ -52,6 +55,11 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto scale = args[1];
auto bias = args[2];
auto dims = x->get_shape().lens();
auto dtype = x->get_shape().type();
if(not contains(valid_types, dtype))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double).");
auto ndims = dims.size();
assert(ndims >= 2);
auto kdims = ndims - 2;
......@@ -65,7 +73,7 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(epsilon);
auto epsilon_literal = info.add_literal(literal{shape{dtype}, {epsilon}});
auto epsilon_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast =
......