Unverified Commit 4d59b7c7 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Fix TF parsing for creating literals and Fix name lookups for input params (#1298)

Bug 1: create_literal was using back_inserter to copy vector with already allocated size, causing double the size of literal.
Fix 1 : not use back_inserter
Bug 2: Input param to model can be from operation that has multiple output, in that case name of the input param would contain : e.g. input_1:0
Fix 2: Look for : and take substring
parent 5a87fcbd
...@@ -216,7 +216,7 @@ static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& da ...@@ -216,7 +216,7 @@ static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& da
std::fill(data_vals.begin(), data_vals.end(), data[0]); std::fill(data_vals.begin(), data_vals.end(), data[0]);
} }
else else
copy(data.begin(), data.end(), std::back_inserter(data_vals)); copy(data.begin(), data.end(), data_vals.begin());
return data_vals; return data_vals;
} }
...@@ -329,33 +329,37 @@ void tf_parser::parse_node(const std::string& name) ...@@ -329,33 +329,37 @@ void tf_parser::parse_node(const std::string& name)
auto&& node = nodes.at(name); auto&& node = nodes.at(name);
if(not is_valid_op(node)) if(not is_valid_op(node))
return; return;
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
for(auto&& input : node.input()) for(auto&& input : node.input())
{ {
// control dependencies (signified by ^ before the name) are ignored // control dependencies (signified by ^ before the name) are ignored
if(contains(input, "^")) if(contains(input, "^"))
continue; continue;
if(nodes.count(input) > 0) std::string input_name = input;
// if input has trailing `:0` index then remove it
auto multi_out_idx = input.find(':');
if(multi_out_idx != std::string::npos && input.substr(multi_out_idx + 1) == "0")
{
input_name = input.substr(0, multi_out_idx);
}
if(nodes.count(input_name) > 0)
{ {
std::string iname;
// input was from a node with multiple outputs // input was from a node with multiple outputs
if(contains(input, ':')) if(contains(input_name, ':'))
{ {
iname = input.substr(0, input.find(':')); input_name = input_name.substr(0, input.find(':'));
} }
else else
{ {
iname = get_name(nodes.at(input)); input_name = get_name(nodes.at(input_name));
} }
assert(name != iname); assert(name != input_name);
this->parse_node(iname); this->parse_node(input_name);
args.push_back(instructions.at(input)); args.push_back(instructions.at(input_name));
} }
else else
{ {
args.push_back(instructions.at(input)); args.push_back(instructions.at(input_name));
} }
} }
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment