".github/vscode:/vscode.git/clone" did not exist on "9768e2dc7574c36608bb04ac39a3b79e639a837f"
Commit e814cffb authored by kahmed10's avatar kahmed10 Committed by mvermeulen
Browse files

Add Split op (#401)

* fix pad calc

* simplify ceil calc and remove extra vars

* workaround for nasnet

* formatting

* add split and tests

* formatting

* fix cppcheck and clang-tidy

* fix clang tidy

* refactor to use vector of instruction_ref, add UNDEBUG to clang tidy

* formatting

* fix comment code

* fix comments and tidy

* formatting

* fix syntax error

* fix tidy

* remove namespace comment
parent 6e6e7f3a
......@@ -110,6 +110,7 @@ rocm_enable_clang_tidy(
HEADER_FILTER
".*hpp"
EXTRA_ARGS
-UNDEBUG
-DMIGRAPHX_USE_CLANG_TIDY
"-Dmain\\\\(...\\\\)=main\\\\(__VA_ARGS__\\\\) // NOLINT"
# CLANG_ARGS
......
......@@ -26,7 +26,8 @@ struct tf_parser
{
using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
using node_map = std::map<std::string, tensorflow::NodeDef>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
using op_func =
std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
std::vector<tensorflow::NodeDef> input_nodes;
......@@ -78,6 +79,14 @@ struct tf_parser
return result;
}
std::vector<instruction_ref> to_nhwc(const std::vector<instruction_ref>& args)
{
std::vector<instruction_ref> result(args.size());
std::transform(
args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nhwc(ins); });
return result;
}
std::vector<size_t>
parse_axes(const attribute_map& attributes, const std::string& s, const size_t num_dims) const
{
......@@ -200,6 +209,8 @@ struct tf_parser
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Slice", &tf_parser::parse_slice, false);
add_mem_op("Split", &tf_parser::parse_split, false);
add_mem_op("SplitV", &tf_parser::parse_split, false);
add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>, false);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
......@@ -207,19 +218,24 @@ struct tf_parser
}
template <class F>
void add_op(std::string name, F f, bool transpose = true)
void add_op(const std::string& name, F f, bool transpose = true)
{
if(transpose)
{
ops.emplace(name,
op_func{[=](const attribute_map& attributes,
const std::vector<instruction_ref>& args) -> instruction_ref {
return to_nhwc(f(attributes, to_nchw(args)));
}});
ops.emplace(
name,
op_func{
[=](const attribute_map& attributes, const std::vector<instruction_ref>& args) {
return std::vector<instruction_ref>{to_nhwc(f(attributes, to_nchw(args)))};
}});
}
else
{
ops.emplace(name, f);
ops.emplace(name,
op_func{[=](const attribute_map& attributes,
const std::vector<instruction_ref>& args) {
return std::vector<instruction_ref>{f(attributes, args)};
}});
}
}
......@@ -809,6 +825,84 @@ struct tf_parser
return prog.add_instruction(Op{axis}, make_contiguous(args[0]));
}
// Parse a TF "Split" or "SplitV" node into one slice instruction per output.
//
// Argument layouts (per the TF op definitions):
//   Split:  args = {split_axis, input}; the "num_split" attribute gives the
//           number of equal-sized pieces.
//   SplitV: args = {input, size_splits, split_axis}; sizes come from a
//           constant int32 tensor and may be uneven.
//
// Returns the slices in output order; a single-output split degenerates to
// an identity instruction.
std::vector<instruction_ref> parse_split(const std::string&,
                                         const attribute_map& attributes,
                                         std::vector<instruction_ref> args)
{
    // SplitV passes the split sizes as an extra input, so it has 3 args
    bool vector_as_input = args.size() == 3;
    int num_outputs      = 1;
    auto axis_arg        = args[0];
    auto input_arg       = args[1];
    if(vector_as_input)
    {
        input_arg = args[0];
        axis_arg  = args[2];
    }

    if(contains(attributes, "num_split"))
        num_outputs = attributes.at("num_split").i();

    std::vector<int> splits(num_outputs);
    // running end points of each slice along the split axis, starting at 0
    std::vector<int> slice_pos{0};
    if(vector_as_input)
    {
        // SplitV: read the (constant) size vector and derive the output count
        splits      = args[1]->eval().get<int32_t>().to_vector();
        num_outputs = splits.size();
    }
    assert(num_outputs > 0);

    if(num_outputs == 1)
        return std::vector<instruction_ref>{prog.add_instruction(op::identity{}, input_arg)};

    auto lens     = input_arg->get_shape().lens();
    auto num_dims = lens.size();
    int axis      = axis_arg->eval().at<int32_t>();
    // TF permits a negative split axis (counted from the last dimension);
    // normalize it before using it to index into lens
    if(axis < 0)
        axis += num_dims;
    assert(axis >= 0 and static_cast<size_t>(axis) < num_dims);
    // ensure split is made evenly if "num_split" is used
    assert(vector_as_input or lens[axis] % num_outputs == 0);
    auto split_size = lens[axis] / num_outputs;

    // push back first end point of slice
    if(vector_as_input)
    {
        slice_pos.push_back(splits[0]);
    }
    else
    {
        slice_pos.push_back(split_size);
    }

    // calculate remaining end points for each slice (prefix sums of the size
    // vector for SplitV, even multiples of split_size otherwise)
    for(auto i = 1; i < num_outputs; i++)
    {
        if(vector_as_input)
        {
            splits[i] += splits[i - 1];
            slice_pos.push_back(splits[i]);
        }
        else
        {
            slice_pos.push_back((i + 1) * split_size);
        }
    }

    // emit one full-rank slice per output; only the split axis is narrowed
    std::vector<instruction_ref> result;
    for(auto i = 0; i < num_outputs; i++)
    {
        op::slice op;
        op.axes = std::vector<int64_t>(num_dims);
        std::iota(op.axes.begin(), op.axes.end(), 0);
        op.starts       = std::vector<int64_t>(num_dims, 0);
        op.ends         = std::vector<int64_t>(lens.begin(), lens.end());
        op.starts[axis] = slice_pos[i];
        op.ends[axis]   = slice_pos[i + 1];
        result.push_back(prog.add_instruction(op, input_arg));
    }
    return result;
}
instruction_ref parse_squeeze(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
......@@ -939,23 +1033,42 @@ struct tf_parser
continue;
if(nodes.count(input) > 0)
{
auto&& iname = get_name(nodes.at(input));
std::string iname;
// input was from a node with multiple outputs
if(contains(input, ':'))
{
iname = input.substr(0, input.find(':'));
}
else
{
iname = get_name(nodes.at(input));
}
assert(name != iname);
this->parse_node(iname);
args.push_back(instructions.at(iname));
args.push_back(instructions.at(input));
}
else
{
args.push_back(instructions.at(input));
}
}
std::vector<instruction_ref> result;
if(ops.count(node.op()) == 0)
{
instructions[name] = prog.add_instruction(op::unknown{node.op()}, args);
result.push_back(prog.add_instruction(op::unknown{node.op()}, args));
}
else
{
instructions[name] = ops[node.op()](get_attributes(node), args);
result = ops[node.op()](get_attributes(node), args);
}
assert(!result.empty());
// First output has no ":" delimiter
instructions[name] = result.front();
for(size_t i = 1; i < result.size(); i++)
{
instructions[name + ":" + std::to_string(i)] = result.at(i);
}
}
}
......
......@@ -273,6 +273,33 @@ def softmax_test(g1):
tf.nn.softmax(g1_input, name='softmax')
@tf_test
def split_test(g1):
    # Even 3-way split of a (5, 30) tensor along axis 1, with adjacent
    # pieces recombined by two concats so every output is consumed.
    with g1.as_default():
        inp = tf.placeholder(tf.float32, shape=(5, 30), name='0')
        parts = tf.split(inp, 3, 1, name='split')
        tf.concat([parts[0], parts[1]], axis=1, name='concat1')
        tf.concat([parts[1], parts[2]], axis=1, name='concat2')
@tf_test
def split_test_one_output(g1):
    # Degenerate split: a single piece along axis 1 (a passthrough).
    with g1.as_default():
        inp = tf.placeholder(tf.float32, shape=(5, 30), name='0')
        tf.split(inp, 1, 1, name='split')
@tf_test
def split_test_vector_as_input(g1):
    # Uneven split (SplitV) of a (5, 30) tensor along axis 1 into sizes
    # [4, 15, 11], with adjacent pieces recombined by two concats.
    with g1.as_default():
        inp = tf.placeholder(tf.float32, shape=(5, 30), name='0')
        parts = tf.split(inp, [4, 15, 11], 1, name='split')
        tf.concat([parts[0], parts[1]], axis=1, name='concat1')
        tf.concat([parts[1], parts[2]], axis=1, name='concat2')
@tf_test
def sqdiff_test(g1):
with g1.as_default():
......
......@@ -474,6 +474,61 @@ TEST_CASE(softmax_test)
EXPECT(p == prog);
}
TEST_CASE(split_test)
{
    // Even 3-way split of a (5, 30) tensor along axis 1, followed by two
    // concats. The literals mirror the constant inputs in the TF graph; the
    // instruction order below must match what the parser emits.
    migraphx::program p;
    std::vector<int64_t> slice_axes{0, 1};
    auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
    p.add_literal(3); // num_splits
    p.add_literal(1); // split axis
    p.add_literal(1); // concat axis
    p.add_literal(1); // concat axis
    auto s0 = p.add_instruction(migraphx::op::slice{slice_axes, {0, 0}, {5, 10}}, input);
    auto s1 = p.add_instruction(migraphx::op::slice{slice_axes, {0, 10}, {5, 20}}, input);
    auto s2 = p.add_instruction(migraphx::op::slice{slice_axes, {0, 20}, {5, 30}}, input);
    p.add_instruction(migraphx::op::concat{1}, s0, s1);
    p.add_instruction(migraphx::op::concat{1}, s1, s2);

    auto prog = migraphx::parse_tf("split_test.pb", false);
    EXPECT(p == prog);
}
TEST_CASE(split_test_one_output)
{
    // A split into a single piece is parsed as an identity instruction.
    migraphx::program p;
    auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
    p.add_literal(1); // num_splits
    p.add_literal(1); // split axis
    p.add_instruction(migraphx::op::identity{}, input);

    auto prog = migraphx::parse_tf("split_test_one_output.pb", false);
    EXPECT(p == prog);
}
TEST_CASE(split_test_vector_as_input)
{
    // SplitV: uneven split sizes {4, 15, 11} along axis 1 of a (5, 30)
    // input, followed by two concats. Instruction order must match the
    // parser's output for the program comparison to hold.
    migraphx::program p;
    std::vector<int64_t> slice_axes{0, 1};
    auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
    // split sizes
    p.add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {3}}, {4, 15, 11}});
    p.add_literal(1); // split axis
    p.add_literal(1); // concat axis
    p.add_literal(1); // concat axis
    auto s0 = p.add_instruction(migraphx::op::slice{slice_axes, {0, 0}, {5, 4}}, input);
    auto s1 = p.add_instruction(migraphx::op::slice{slice_axes, {0, 4}, {5, 19}}, input);
    auto s2 = p.add_instruction(migraphx::op::slice{slice_axes, {0, 19}, {5, 30}}, input);
    p.add_instruction(migraphx::op::concat{1}, s0, s1);
    p.add_instruction(migraphx::op::concat{1}, s1, s2);

    auto prog = migraphx::parse_tf("split_test_vector_as_input.pb", false);
    EXPECT(p == prog);
}
TEST_CASE(sqdiff_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment