Commit 0420b20d authored by Khalique's avatar Khalique
Browse files

modifications to operators

parent 52df0ca8
...@@ -177,13 +177,13 @@ struct tf_parser ...@@ -177,13 +177,13 @@ struct tf_parser
add_mem_op("GatherV2", &tf_parser::parse_gather, false); add_mem_op("GatherV2", &tf_parser::parse_gather, false);
add_mem_op("MatMul", &tf_parser::parse_matmul, false); add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean); add_mem_op("Mean", &tf_parser::parse_mean, false);
add_mem_op("OneHot", &tf_parser::parse_onehot, false); add_mem_op("OneHot", &tf_parser::parse_onehot, false);
add_mem_op("Pack", &tf_parser::parse_pack, false); add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad); add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape, false); add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Slice", &tf_parser::parse_slice, false); add_mem_op("Slice", &tf_parser::parse_slice, false);
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>, false);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false); add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false); add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
add_mem_op("Transpose", &tf_parser::parse_transpose, false); add_mem_op("Transpose", &tf_parser::parse_transpose, false);
...@@ -548,7 +548,7 @@ struct tf_parser ...@@ -548,7 +548,7 @@ struct tf_parser
} }
if(contains(attributes, "transpose_b")) if(contains(attributes, "transpose_b"))
{ {
transb = attributes.at("transpose_a").b(); transb = attributes.at("transpose_b").b();
} }
if(contains(attributes, "adj_x")) if(contains(attributes, "adj_x"))
...@@ -575,23 +575,24 @@ struct tf_parser ...@@ -575,23 +575,24 @@ struct tf_parser
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
bool keep_dims = attributes.at("keep_dims").b(); bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3}; // std::vector<int32_t> hw_axes{2, 3};
// check if conditions for GlobalAvgPool are met // check if conditions for GlobalAvgPool are met
auto lens = args[0]->get_shape().lens(); auto lens = args[0]->get_shape().lens();
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector(), lens.size()); auto axes = args[1]->eval().get<int32_t>().to_vector();
std::vector<int64_t> axes_int64 = std::vector<int64_t>(axes.begin(), axes.end());
if(axes == hw_axes and lens.size() == 4)
{ // if(axes == hw_axes and lens.size() == 4)
op::pooling op{"average"}; // {
op.lengths[0] = lens[2]; // op::pooling op{"average"};
op.lengths[1] = lens[3]; // op.lengths[0] = lens[2];
auto l0 = prog.add_instruction(op, args.front()); // op.lengths[1] = lens[3];
if(keep_dims) auto l0 = prog.add_instruction(op::reduce_mean{axes_int64}, args.front());
return l0; if(keep_dims)
return prog.add_instruction( return l0;
op::squeeze{std::vector<int64_t>(hw_axes.begin(), hw_axes.end())}, l0); return prog.add_instruction(
} op::squeeze{axes_int64}, l0);
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation"); // }
// MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
} }
instruction_ref instruction_ref
...@@ -763,14 +764,24 @@ struct tf_parser ...@@ -763,14 +764,24 @@ struct tf_parser
} }
} }
instruction_ref // template to facilitate the logsoftmax later
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args) template <class Op>
instruction_ref parse_softmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{ {
auto dims = args.front()->get_shape().lens(); int axis = -1;
auto r = auto num_dims = args[0]->get_shape().lens().size();
prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front()); if(contains(attributes, "axis"))
auto s = prog.add_instruction(op::softmax{}, r); {
return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s); axis = static_cast<int>(attributes.at("axis").i());
}
if(axis < 0)
{
axis += num_dims;
}
return prog.add_instruction(Op{axis}, make_contiguous(args[0]));
} }
instruction_ref parse_squeeze(const std::string&, instruction_ref parse_squeeze(const std::string&,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment