Commit bca45a1a authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

prototype changes

parent c900e382
...@@ -96,7 +96,8 @@ struct convolution ...@@ -96,7 +96,8 @@ struct convolution
} }
if(not x_shape.dynamic() and not w_shape.dynamic() and if(not x_shape.dynamic() and not w_shape.dynamic() and
x_shape.lens().at(1) != (w_shape.lens().at(1) * group)) (x_shape.lens().at(1) != (w_shape.lens().at(1) * group) and
x_shape.lens().back() != (w_shape.lens().back() * group)))
MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers"); MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
if(x_shape.dynamic() or w_shape.dynamic()) if(x_shape.dynamic() or w_shape.dynamic())
...@@ -130,9 +131,11 @@ struct convolution ...@@ -130,9 +131,11 @@ struct convolution
// when padding is {x0_begin, x1_begin, ... x0_end , x1_end, ...} // when padding is {x0_begin, x1_begin, ... x0_end , x1_end, ...}
padding_factor = padding[i] + padding[i + num_spatial_dims]; padding_factor = padding[i] + padding[i + num_spatial_dims];
} }
// k y x c
ret.push_back(std::size_t(std::max<std::ptrdiff_t>( ret.push_back(std::size_t(std::max<std::ptrdiff_t>(
1, 1,
(x_lens[i + 2] - (1 + dilation[i] * (w_lens[i + 2] - 1)) + padding_factor) / (x_lens[i + 1] - (1 + dilation[i] * (*(w_lens.rbegin() + i) - 1)) + padding_factor) /
stride[i] + stride[i] +
1))); 1)));
} }
...@@ -200,11 +203,13 @@ struct convolution ...@@ -200,11 +203,13 @@ struct convolution
shape fixed_compute_shape(shape x_shape, shape w_shape) const shape fixed_compute_shape(shape x_shape, shape w_shape) const
{ {
std::vector<size_t> output_lens{x_shape.lens()[0], w_shape.lens()[0]}; std::vector<size_t> output_lens{x_shape.lens()[0]};
// std::vector<size_t> output_lens{x_shape.lens()[0], w_shape.lens()[0]};
auto spatial_lens = calc_conv_lens(x_shape.lens(), w_shape.lens()); auto spatial_lens = calc_conv_lens(x_shape.lens(), w_shape.lens());
std::for_each(spatial_lens.begin(), spatial_lens.end(), [&output_lens](auto x) { std::for_each(spatial_lens.begin(), spatial_lens.end(), [&output_lens](auto x) {
output_lens.push_back(x); output_lens.push_back(x);
}); });
output_lens.push_back(w_shape.lens()[0]);
return x_shape.with_lens(output_lens); return x_shape.with_lens(output_lens);
} }
......
...@@ -115,9 +115,9 @@ struct miopen_convolution ...@@ -115,9 +115,9 @@ struct miopen_convolution
tensor_args, tensor_args,
args[2].implicit(), args[2].implicit(),
workspace_size); workspace_size);
if(status != miopenStatusSuccess) // if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() + // MIGRAPHX_THROW("MIOpen " + op.name() +
" : running convolution using find_2.0 failed"); // " : running convolution using find_2.0 failed");
return args[3]; return args[3];
} }
......
...@@ -137,7 +137,10 @@ inline tensor_descriptor make_tensor(const migraphx::shape& os, bool pack = fals ...@@ -137,7 +137,10 @@ inline tensor_descriptor make_tensor(const migraphx::shape& os, bool pack = fals
{ {
MIGRAPHX_THROW("MAKE_TENSOR: unsupported type"); MIGRAPHX_THROW("MAKE_TENSOR: unsupported type");
} }
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
// miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
miopenSetNdTensorDescriptorWithLayout(t.get(), d, miopenTensorLayout_t::miopenTensorNHWC, lens.data(), lens.size());
return t; return t;
} }
......
...@@ -80,17 +80,17 @@ struct op_parser : auto_register<register_op_parser_action, Derived> ...@@ -80,17 +80,17 @@ struct op_parser : auto_register<register_op_parser_action, Derived>
{ {
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
auto& self = static_cast<const Derived&>(*this); auto& self = static_cast<const Derived&>(*this);
if(self.transpose()) // if(self.transpose())
{ // {
result = implicit_multi_op(self.parse(opd, parser, info, parser.to_nchw(args))); // result = implicit_multi_op(self.parse(opd, parser, info, parser.to_nchw(args)));
std::transform(result.begin(), result.end(), result.begin(), [&](auto ins) { // std::transform(result.begin(), result.end(), result.begin(), [&](auto ins) {
return parser.to_nhwc(ins); // return parser.to_nhwc(ins);
}); // });
} // }
else // else
{ // {
result = implicit_multi_op(self.parse(opd, parser, info, args)); result = implicit_multi_op(self.parse(opd, parser, info, args));
} // }
return result; return result;
} }
}; };
......
...@@ -47,7 +47,7 @@ struct parse_conv : op_parser<parse_conv> ...@@ -47,7 +47,7 @@ struct parse_conv : op_parser<parse_conv>
{ {
std::vector<size_t> stride; std::vector<size_t> stride;
copy(info.attributes.at("strides").list().i(), std::back_inserter(stride)); copy(info.attributes.at("strides").list().i(), std::back_inserter(stride));
parser.reorder_data(stride); // parser.reorder_data(stride);
if(stride.size() != 4) if(stride.size() != 4)
{ {
MIGRAPHX_THROW("strides should have 4 values"); MIGRAPHX_THROW("strides should have 4 values");
...@@ -59,7 +59,7 @@ struct parse_conv : op_parser<parse_conv> ...@@ -59,7 +59,7 @@ struct parse_conv : op_parser<parse_conv>
{ {
std::vector<size_t> dilation; std::vector<size_t> dilation;
copy(info.attributes.at("dilations").list().i(), std::back_inserter(dilation)); copy(info.attributes.at("dilations").list().i(), std::back_inserter(dilation));
parser.reorder_data(dilation); // parser.reorder_data(dilation);
if(dilation.size() != 4) if(dilation.size() != 4)
{ {
MIGRAPHX_THROW("dilation should have 4 values"); MIGRAPHX_THROW("dilation should have 4 values");
......
...@@ -60,8 +60,10 @@ struct parse_pack : op_parser<parse_pack> ...@@ -60,8 +60,10 @@ struct parse_pack : op_parser<parse_pack>
[&](instruction_ref arg) { [&](instruction_ref arg) {
return info.add_instruction(make_op("unsqueeze", {{"axes", {axis}}}), arg); return info.add_instruction(make_op("unsqueeze", {{"axes", {axis}}}), arg);
}); });
return parser.to_nhwc( return
info.add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args)); info.add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args);
// return parser.to_nhwc(
// info.add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args));
} }
}; };
......
...@@ -71,7 +71,10 @@ instruction_ref tf_parser::to_nchw(instruction_ref ins) const ...@@ -71,7 +71,10 @@ instruction_ref tf_parser::to_nchw(instruction_ref ins) const
instruction_ref tf_parser::to_kcxy(instruction_ref ins) const instruction_ref tf_parser::to_kcxy(instruction_ref ins) const
{ {
return mm->add_instruction(make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), ins); // x y c k
// k c x y
// k y x c
return mm->add_instruction(make_op("transpose", {{"permutation", {3, 1, 0, 2}}}), ins);
} }
std::vector<instruction_ref> tf_parser::to_nchw(const std::vector<instruction_ref>& args) const std::vector<instruction_ref> tf_parser::to_nchw(const std::vector<instruction_ref>& args) const
...@@ -282,17 +285,19 @@ void tf_parser::parse_graph(const tensorflow::GraphDef& graph) ...@@ -282,17 +285,19 @@ void tf_parser::parse_graph(const tensorflow::GraphDef& graph)
} }
else else
{ {
if(is_nhwc and dims.size() >= 4) // if(is_nhwc and dims.size() >= 4)
{ // {
this->reorder_data(dims); // this->reorder_data(dims);
} // }
std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) { std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) {
return static_cast<int>(dim) <= 0 ? batch_size : dim; return static_cast<int>(dim) <= 0 ? batch_size : dim;
}); });
} }
shape s = shape{shape_type, dims}; shape s = shape{shape_type, dims};
instructions[name] = to_nhwc(mm->add_parameter(name, s)); instructions[name] = mm->add_parameter(name, s);
// instructions[name] = to_nhwc(mm->add_parameter(name, s));
} }
for(auto&& p : nodes) for(auto&& p : nodes)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment