Commit 6e058792 authored by Khalique's avatar Khalique
Browse files

formatting

parent c8a91e20
......@@ -24,10 +24,10 @@ inline namespace MIGRAPHX_INLINE_NS {
struct tf_parser
{
using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
using node_map = std::unordered_map<std::string, tensorflow::NodeDef>;
using node_map = std::unordered_map<std::string, tensorflow::NodeDef>;
// using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
std::vector<tensorflow::NodeDef> input_nodes;
std::unordered_map<std::string, instruction_ref> instructions;
......@@ -130,7 +130,7 @@ struct tf_parser
{
epsilon = attributes.at("epsilon").f();
}
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return prog.add_instruction(op, std::move(args));
}
......@@ -140,7 +140,7 @@ struct tf_parser
{
// get index for axis within args
std::size_t axis_idx = attributes.at("N").i();
std::size_t axis = args[axis_idx]->eval().at<int64_t>();
std::size_t axis = args[axis_idx]->eval().at<int64_t>();
op::concat op{axis};
return prog.add_instruction(op, std::move(args));
}
......@@ -164,7 +164,7 @@ struct tf_parser
{
op.padding_mode = op::convolution::same;
}
else if (pad_mode.find("EXPLICIT") != std::string::npos)
else if(pad_mode.find("EXPLICIT") != std::string::npos)
{
std::vector<std::size_t> padding(4);
copy(attributes.at("explicit_paddings").list().i(), padding.begin());
......@@ -200,7 +200,7 @@ struct tf_parser
// std::vector<instruction_ref> args)
// {
// op::pooling op{starts_with(name, "Max") ? "max" : "average"};
// if(contains(attributes, "pads"))
// {
// std::vector<std::size_t> padding(4);
......@@ -254,12 +254,12 @@ struct tf_parser
nodes = get_nodes(graph, input_nodes);
for(auto&& input : input_nodes)
{
const std::string& name = input.name();
const std::string& name = input.name();
attribute_map input_attrs = get_attributes(input);
shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
shape s = shape{shape_type, dims};
instructions[name] = prog.add_parameter(name, s);
shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
shape s = shape{shape_type, dims};
instructions[name] = prog.add_parameter(name, s);
if(is_nhwc)
{
// nhwc to nchw
......@@ -308,19 +308,17 @@ struct tf_parser
static attribute_map get_attributes(const tensorflow::NodeDef& node)
{
attribute_map result;
for (auto&& attr : node.attr())
for(auto&& attr : node.attr())
{
result[attr.first] = attr.second;
}
return result;
}
static std::string get_name(const tensorflow::NodeDef& node)
{
return node.name();
}
static std::string get_name(const tensorflow::NodeDef& node) { return node.name(); }
static node_map get_nodes(const tensorflow::GraphDef& graph, std::vector<tensorflow::NodeDef>& input_nodes)
static node_map get_nodes(const tensorflow::GraphDef& graph,
std::vector<tensorflow::NodeDef>& input_nodes)
{
node_map result;
for(auto&& node : graph.node())
......@@ -381,8 +379,7 @@ struct tf_parser
break; // throw std::runtime_error("Unsupported type VARIANT");
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;
default:
break;
default: break;
}
return shape_type;
}
......@@ -397,27 +394,32 @@ struct tf_parser
if(!t.tensor_content().empty()) // has raw data
{
const std::string&s = t.tensor_content();
const std::string& s = t.tensor_content();
switch(t.dtype())
{
case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT: return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT16: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64: return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT16:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE: return literal{{shape::double_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
default:
break;
default: break;
}
MIGRAPHX_THROW("Invalid tensor type");
}
......@@ -449,11 +451,9 @@ struct tf_parser
case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
default:
break;
default: break;
}
MIGRAPHX_THROW("Invalid tensor type");
}
static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
......@@ -466,9 +466,6 @@ struct tf_parser
}
return dims;
}
};
program parse_tf(const std::string& name, bool is_nhwc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment