"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "b9c140fe2141916eb7bf7c6fc7d3e0cd6c68bd8c"
Commit 0420b20d authored by Khalique's avatar Khalique
Browse files

modifications to operators

parent 52df0ca8
......@@ -177,13 +177,13 @@ struct tf_parser
add_mem_op("GatherV2", &tf_parser::parse_gather, false);
add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Mean", &tf_parser::parse_mean, false);
add_mem_op("OneHot", &tf_parser::parse_onehot, false);
add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Slice", &tf_parser::parse_slice, false);
add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>, false);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
add_mem_op("Transpose", &tf_parser::parse_transpose, false);
......@@ -548,7 +548,7 @@ struct tf_parser
}
if(contains(attributes, "transpose_b"))
{
transb = attributes.at("transpose_a").b();
transb = attributes.at("transpose_b").b();
}
if(contains(attributes, "adj_x"))
......@@ -575,23 +575,24 @@ struct tf_parser
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3};
// std::vector<int32_t> hw_axes{2, 3};
// check if conditions for GlobalAvgPool are met
auto lens = args[0]->get_shape().lens();
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector(), lens.size());
if(axes == hw_axes and lens.size() == 4)
{
op::pooling op{"average"};
op.lengths[0] = lens[2];
op.lengths[1] = lens[3];
auto l0 = prog.add_instruction(op, args.front());
if(keep_dims)
return l0;
return prog.add_instruction(
op::squeeze{std::vector<int64_t>(hw_axes.begin(), hw_axes.end())}, l0);
}
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
auto axes = args[1]->eval().get<int32_t>().to_vector();
std::vector<int64_t> axes_int64 = std::vector<int64_t>(axes.begin(), axes.end());
// if(axes == hw_axes and lens.size() == 4)
// {
// op::pooling op{"average"};
// op.lengths[0] = lens[2];
// op.lengths[1] = lens[3];
auto l0 = prog.add_instruction(op::reduce_mean{axes_int64}, args.front());
if(keep_dims)
return l0;
return prog.add_instruction(
op::squeeze{axes_int64}, l0);
// }
// MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
}
instruction_ref
......@@ -763,14 +764,24 @@ struct tf_parser
}
}
instruction_ref
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
// template to facilitate the logsoftmax later
template <class Op>
instruction_ref parse_softmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
auto dims = args.front()->get_shape().lens();
auto r =
prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
auto s = prog.add_instruction(op::softmax{}, r);
return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
int axis = -1;
auto num_dims = args[0]->get_shape().lens().size();
if(contains(attributes, "axis"))
{
axis = static_cast<int>(attributes.at("axis").i());
}
if(axis < 0)
{
axis += num_dims;
}
return prog.add_instruction(Op{axis}, make_contiguous(args[0]));
}
instruction_ref parse_squeeze(const std::string&,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment