Commit 6e058792 authored by Khalique's avatar Khalique
Browse files

formatting

parent c8a91e20
...@@ -164,7 +164,7 @@ struct tf_parser ...@@ -164,7 +164,7 @@ struct tf_parser
{ {
op.padding_mode = op::convolution::same; op.padding_mode = op::convolution::same;
} }
else if (pad_mode.find("EXPLICIT") != std::string::npos) else if(pad_mode.find("EXPLICIT") != std::string::npos)
{ {
std::vector<std::size_t> padding(4); std::vector<std::size_t> padding(4);
copy(attributes.at("explicit_paddings").list().i(), padding.begin()); copy(attributes.at("explicit_paddings").list().i(), padding.begin());
...@@ -308,19 +308,17 @@ struct tf_parser ...@@ -308,19 +308,17 @@ struct tf_parser
static attribute_map get_attributes(const tensorflow::NodeDef& node) static attribute_map get_attributes(const tensorflow::NodeDef& node)
{ {
attribute_map result; attribute_map result;
for (auto&& attr : node.attr()) for(auto&& attr : node.attr())
{ {
result[attr.first] = attr.second; result[attr.first] = attr.second;
} }
return result; return result;
} }
static std::string get_name(const tensorflow::NodeDef& node) static std::string get_name(const tensorflow::NodeDef& node) { return node.name(); }
{
return node.name();
}
static node_map get_nodes(const tensorflow::GraphDef& graph, std::vector<tensorflow::NodeDef>& input_nodes) static node_map get_nodes(const tensorflow::GraphDef& graph,
std::vector<tensorflow::NodeDef>& input_nodes)
{ {
node_map result; node_map result;
for(auto&& node : graph.node()) for(auto&& node : graph.node())
...@@ -381,8 +379,7 @@ struct tf_parser ...@@ -381,8 +379,7 @@ struct tf_parser
break; // throw std::runtime_error("Unsupported type VARIANT"); break; // throw std::runtime_error("Unsupported type VARIANT");
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break; case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break; case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;
default: default: break;
break;
} }
return shape_type; return shape_type;
} }
...@@ -397,27 +394,32 @@ struct tf_parser ...@@ -397,27 +394,32 @@ struct tf_parser
if(!t.tensor_content().empty()) // has raw data if(!t.tensor_content().empty()) // has raw data
{ {
const std::string&s = t.tensor_content(); const std::string& s = t.tensor_content();
switch(t.dtype()) switch(t.dtype())
{ {
case tensorflow::DataType::DT_INVALID: throw std::runtime_error(""); case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT: return literal{{shape::float_type, dims}, s.data()}; case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8: return literal{{shape::int32_type, dims}, s.data()}; case tensorflow::DataType::DT_INT8: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16: return literal{{shape::int32_type, dims}, s.data()}; case tensorflow::DataType::DT_UINT16:
case tensorflow::DataType::DT_INT16: return literal{{shape::int32_type, dims}, s.data()}; return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32: return literal{{shape::int32_type, dims}, s.data()}; case tensorflow::DataType::DT_INT16:
case tensorflow::DataType::DT_INT64: return literal{{shape::int64_type, dims}, s.data()}; return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_STRING: throw std::runtime_error(""); case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: return literal{{shape::int32_type, dims}, s.data()}; case tensorflow::DataType::DT_BOOL: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()}; case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE: return literal{{shape::double_type, dims}, s.data()}; case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
default: default: break;
break;
} }
MIGRAPHX_THROW("Invalid tensor type"); MIGRAPHX_THROW("Invalid tensor type");
} }
...@@ -449,11 +451,9 @@ struct tf_parser ...@@ -449,11 +451,9 @@ struct tf_parser
case tensorflow::DataType::DT_UINT64: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
default: default: break;
break;
} }
MIGRAPHX_THROW("Invalid tensor type"); MIGRAPHX_THROW("Invalid tensor type");
} }
static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s) static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
...@@ -466,9 +466,6 @@ struct tf_parser ...@@ -466,9 +466,6 @@ struct tf_parser
} }
return dims; return dims;
} }
}; };
program parse_tf(const std::string& name, bool is_nhwc) program parse_tf(const std::string& name, bool is_nhwc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment