Commit c0936f8f authored by Khalique

made changes according to feedback

parent 7e2cf4e8
@@ -36,7 +36,7 @@ struct tf_parser
     std::unordered_map<std::string, op_func> ops;

-    std::vector<size_t> parse_axes(attribute_map& attributes, const std::string& s) const
+    std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const
     {
         auto attrs = attributes.at(s).list().i();
         std::vector<size_t> axes;
@@ -45,7 +45,7 @@ struct tf_parser
         {
             for(size_t& axis : axes)
             {
-                parse_axis(axis);
+                axis = parse_axis(axis);
             }
         }
         return axes;
@@ -57,27 +57,27 @@ struct tf_parser
         std::vector<T> new_data(prev_data.size());
         for(size_t i = 0; i < new_data.size(); i++)
         {
-            auto new_idx = i;
-            parse_axis(new_idx);
+            auto new_idx = parse_axis(i);
             new_data.at(new_idx) = prev_data.at(i);
         }
         prev_data = new_data;
     }

     template <class T>
-    void parse_axis(T& dim) const
+    T parse_axis(const T& dim) const
     {
         if(is_nhwc)
         {
             switch(dim)
             {
-            case 0: dim = 0; break;
-            case 1: dim = 2; break;
-            case 2: dim = 3; break;
-            case 3: dim = 1; break;
-            default: break;
+            case 0: return 0;
+            case 1: return 2;
+            case 2: return 3;
+            case 3: return 1;
+            default: return T{dim};
             }
         }
+        return T{dim};
     }

     std::vector<int64_t> get_axes(size_t num_axes) const
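For context: the refactor above changes parse_axis from mutating its argument to returning the translated axis, so call sites now assign the result. Below is a minimal standalone sketch of the NHWC-to-NCHW mapping it performs; the is_nhwc member of the real parser is modeled here as a plain parameter, purely for illustration.

#include <cstddef>
#include <iostream>
#include <vector>

// Standalone sketch of the refactored parse_axis: returns the NCHW
// position of an NHWC axis index instead of mutating its argument.
std::size_t parse_axis(std::size_t dim, bool is_nhwc)
{
    if(is_nhwc)
    {
        switch(dim)
        {
        case 0: return 0; // batch stays first
        case 1: return 2; // height follows channels
        case 2: return 3; // width comes last
        case 3: return 1; // channels move next to batch
        }
    }
    return dim;
}

int main()
{
    std::vector<std::size_t> axes = {0, 1, 2, 3};
    for(std::size_t& axis : axes)
        axis = parse_axis(axis, true); // assign the result, as the fixed parse_axes does
    for(auto a : axes)
        std::cout << a << ' '; // prints: 0 2 3 1
    std::cout << '\n';
}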
@@ -224,8 +224,7 @@ struct tf_parser
     {
         // get index for axis within args
         size_t axis_idx = attributes.at("N").i();
-        size_t axis = args[axis_idx]->eval().at<int64_t>();
-        parse_axis(axis);
+        size_t axis = parse_axis(args[axis_idx]->eval().at<int64_t>());
         op::concat op{axis};
         // return only first N arguments (assuming last index is the axis value)
         return prog.add_instruction(
@@ -263,7 +262,6 @@ struct tf_parser
         {
             std::vector<size_t> padding;
             copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
-            reorder_data(padding);
             if(padding.size() != 4)
             {
                 MIGRAPHX_THROW("padding should have 4 values");
@@ -369,7 +367,7 @@ struct tf_parser
         op::reshape op;
         if(args.size() != 2)
             MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
-        literal s = args[1]->get_literal();
+        auto s = args[1]->eval();
         s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
         return prog.add_instruction(op, args[0]);
     }
@@ -674,7 +672,12 @@ struct tf_parser
         {
             std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
             std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
-            return literal{{shape::half_type, dims}, data_uint16};
+            std::vector<half> data_half;
+            std::transform(data_uint16.begin(),
+                           data_uint16.end(),
+                           std::back_inserter(data_half),
+                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
+            return literal{{shape::half_type, dims}, data_half};
         }
         case tensorflow::DataType::DT_DOUBLE:
             return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
...
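A side note on the half-precision branch above: casting a uint16_t* to half* assumes the two types share a 16-bit layout and technically skirts strict aliasing. A well-defined way to express the same bit reinterpretation is to copy the bytes instead; the sketch below uses a hypothetical half stand-in, since the real MIGraphX half type is not part of this diff.

#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for the library's 16-bit float type; only the
// 16-bit storage layout matters for this bit-level copy.
struct half { std::uint16_t bits; };

// Reinterpret raw IEEE binary16 bit patterns as half values without
// casting pointers: memcpy is the strict-aliasing-safe spelling.
std::vector<half> bits_to_half(const std::vector<std::uint16_t>& raw)
{
    static_assert(sizeof(half) == sizeof(std::uint16_t), "layouts must match");
    std::vector<half> out(raw.size());
    if(!raw.empty())
        std::memcpy(out.data(), raw.data(), raw.size() * sizeof(std::uint16_t));
    return out;
}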