Commit 6e058792 authored by Khalique's avatar Khalique
Browse files

formatting

parent c8a91e20
...@@ -24,10 +24,10 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -24,10 +24,10 @@ inline namespace MIGRAPHX_INLINE_NS {
struct tf_parser struct tf_parser
{ {
using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>; using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
using node_map = std::unordered_map<std::string, tensorflow::NodeDef>; using node_map = std::unordered_map<std::string, tensorflow::NodeDef>;
// using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>; // using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>; using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
node_map nodes; node_map nodes;
std::vector<tensorflow::NodeDef> input_nodes; std::vector<tensorflow::NodeDef> input_nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
...@@ -130,7 +130,7 @@ struct tf_parser ...@@ -130,7 +130,7 @@ struct tf_parser
{ {
epsilon = attributes.at("epsilon").f(); epsilon = attributes.at("epsilon").f();
} }
op::batch_norm_inference op{epsilon, momentum, bn_mode}; op::batch_norm_inference op{epsilon, momentum, bn_mode};
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
...@@ -140,7 +140,7 @@ struct tf_parser ...@@ -140,7 +140,7 @@ struct tf_parser
{ {
// get index for axis within args // get index for axis within args
std::size_t axis_idx = attributes.at("N").i(); std::size_t axis_idx = attributes.at("N").i();
std::size_t axis = args[axis_idx]->eval().at<int64_t>(); std::size_t axis = args[axis_idx]->eval().at<int64_t>();
op::concat op{axis}; op::concat op{axis};
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
...@@ -164,7 +164,7 @@ struct tf_parser ...@@ -164,7 +164,7 @@ struct tf_parser
{ {
op.padding_mode = op::convolution::same; op.padding_mode = op::convolution::same;
} }
else if (pad_mode.find("EXPLICIT") != std::string::npos) else if(pad_mode.find("EXPLICIT") != std::string::npos)
{ {
std::vector<std::size_t> padding(4); std::vector<std::size_t> padding(4);
copy(attributes.at("explicit_paddings").list().i(), padding.begin()); copy(attributes.at("explicit_paddings").list().i(), padding.begin());
...@@ -200,7 +200,7 @@ struct tf_parser ...@@ -200,7 +200,7 @@ struct tf_parser
// std::vector<instruction_ref> args) // std::vector<instruction_ref> args)
// { // {
// op::pooling op{starts_with(name, "Max") ? "max" : "average"}; // op::pooling op{starts_with(name, "Max") ? "max" : "average"};
// if(contains(attributes, "pads")) // if(contains(attributes, "pads"))
// { // {
// std::vector<std::size_t> padding(4); // std::vector<std::size_t> padding(4);
...@@ -254,12 +254,12 @@ struct tf_parser ...@@ -254,12 +254,12 @@ struct tf_parser
nodes = get_nodes(graph, input_nodes); nodes = get_nodes(graph, input_nodes);
for(auto&& input : input_nodes) for(auto&& input : input_nodes)
{ {
const std::string& name = input.name(); const std::string& name = input.name();
attribute_map input_attrs = get_attributes(input); attribute_map input_attrs = get_attributes(input);
shape::type_t shape_type = parse_type(input_attrs.at("dtype").type()); shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape()); std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
shape s = shape{shape_type, dims}; shape s = shape{shape_type, dims};
instructions[name] = prog.add_parameter(name, s); instructions[name] = prog.add_parameter(name, s);
if(is_nhwc) if(is_nhwc)
{ {
// nhwc to nchw // nhwc to nchw
...@@ -308,19 +308,17 @@ struct tf_parser ...@@ -308,19 +308,17 @@ struct tf_parser
static attribute_map get_attributes(const tensorflow::NodeDef& node) static attribute_map get_attributes(const tensorflow::NodeDef& node)
{ {
attribute_map result; attribute_map result;
for (auto&& attr : node.attr()) for(auto&& attr : node.attr())
{ {
result[attr.first] = attr.second; result[attr.first] = attr.second;
} }
return result; return result;
} }
static std::string get_name(const tensorflow::NodeDef& node) static std::string get_name(const tensorflow::NodeDef& node) { return node.name(); }
{
return node.name();
}
static node_map get_nodes(const tensorflow::GraphDef& graph, std::vector<tensorflow::NodeDef>& input_nodes) static node_map get_nodes(const tensorflow::GraphDef& graph,
std::vector<tensorflow::NodeDef>& input_nodes)
{ {
node_map result; node_map result;
for(auto&& node : graph.node()) for(auto&& node : graph.node())
...@@ -381,8 +379,7 @@ struct tf_parser ...@@ -381,8 +379,7 @@ struct tf_parser
break; // throw std::runtime_error("Unsupported type VARIANT"); break; // throw std::runtime_error("Unsupported type VARIANT");
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break; case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break; case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;
default: default: break;
break;
} }
return shape_type; return shape_type;
} }
...@@ -397,27 +394,32 @@ struct tf_parser ...@@ -397,27 +394,32 @@ struct tf_parser
if(!t.tensor_content().empty()) // has raw data if(!t.tensor_content().empty()) // has raw data
{ {
const std::string&s = t.tensor_content(); const std::string& s = t.tensor_content();
switch(t.dtype()) switch(t.dtype())
{ {
case tensorflow::DataType::DT_INVALID: throw std::runtime_error(""); case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT: return literal{{shape::float_type, dims}, s.data()}; case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8: return literal{{shape::int32_type, dims}, s.data()}; case tensorflow::DataType::DT_INT8: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16: return literal{{shape::int32_type, dims}, s.data()}; case tensorflow::DataType::DT_UINT16:
case tensorflow::DataType::DT_INT16: return literal{{shape::int32_type, dims}, s.data()}; return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32: return literal{{shape::int32_type, dims}, s.data()}; case tensorflow::DataType::DT_INT16:
case tensorflow::DataType::DT_INT64: return literal{{shape::int64_type, dims}, s.data()}; return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_STRING: throw std::runtime_error(""); case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: return literal{{shape::int32_type, dims}, s.data()}; case tensorflow::DataType::DT_BOOL: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()}; case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE: return literal{{shape::double_type, dims}, s.data()}; case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
default: default: break;
break;
} }
MIGRAPHX_THROW("Invalid tensor type"); MIGRAPHX_THROW("Invalid tensor type");
} }
...@@ -449,11 +451,9 @@ struct tf_parser ...@@ -449,11 +451,9 @@ struct tf_parser
case tensorflow::DataType::DT_UINT64: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
default: default: break;
break;
} }
MIGRAPHX_THROW("Invalid tensor type"); MIGRAPHX_THROW("Invalid tensor type");
} }
static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s) static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
...@@ -466,9 +466,6 @@ struct tf_parser ...@@ -466,9 +466,6 @@ struct tf_parser
} }
return dims; return dims;
} }
}; };
program parse_tf(const std::string& name, bool is_nhwc) program parse_tf(const std::string& name, bool is_nhwc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment