"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "9ca525379aff242745f8e485f615392d2e5aabc7"
Commit bca45a1a authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

prototype changes

parent c900e382
......@@ -96,7 +96,8 @@ struct convolution
}
if(not x_shape.dynamic() and not w_shape.dynamic() and
x_shape.lens().at(1) != (w_shape.lens().at(1) * group))
(x_shape.lens().at(1) != (w_shape.lens().at(1) * group) and
x_shape.lens().back() != (w_shape.lens().back() * group)))
MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
if(x_shape.dynamic() or w_shape.dynamic())
......@@ -130,9 +131,11 @@ struct convolution
// when padding is {x0_begin, x1_begin, ... x0_end , x1_end, ...}
padding_factor = padding[i] + padding[i + num_spatial_dims];
}
// k y x c
ret.push_back(std::size_t(std::max<std::ptrdiff_t>(
1,
(x_lens[i + 2] - (1 + dilation[i] * (w_lens[i + 2] - 1)) + padding_factor) /
(x_lens[i + 1] - (1 + dilation[i] * (*(w_lens.rbegin() + i) - 1)) + padding_factor) /
stride[i] +
1)));
}
......@@ -200,11 +203,13 @@ struct convolution
shape fixed_compute_shape(shape x_shape, shape w_shape) const
{
std::vector<size_t> output_lens{x_shape.lens()[0], w_shape.lens()[0]};
std::vector<size_t> output_lens{x_shape.lens()[0]};
// std::vector<size_t> output_lens{x_shape.lens()[0], w_shape.lens()[0]};
auto spatial_lens = calc_conv_lens(x_shape.lens(), w_shape.lens());
std::for_each(spatial_lens.begin(), spatial_lens.end(), [&output_lens](auto x) {
output_lens.push_back(x);
});
output_lens.push_back(w_shape.lens()[0]);
return x_shape.with_lens(output_lens);
}
......
......@@ -115,9 +115,9 @@ struct miopen_convolution
tensor_args,
args[2].implicit(),
workspace_size);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() +
" : running convolution using find_2.0 failed");
// if(status != miopenStatusSuccess)
// MIGRAPHX_THROW("MIOpen " + op.name() +
// " : running convolution using find_2.0 failed");
return args[3];
}
......
......@@ -137,8 +137,11 @@ inline tensor_descriptor make_tensor(const migraphx::shape& os, bool pack = fals
{
MIGRAPHX_THROW("MAKE_TENSOR: unsupported type");
}
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
// miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
miopenSetNdTensorDescriptorWithLayout(t.get(), d, miopenTensorLayout_t::miopenTensorNHWC, lens.data(), lens.size());
return t;
}
......
......@@ -80,17 +80,17 @@ struct op_parser : auto_register<register_op_parser_action, Derived>
{
std::vector<instruction_ref> result;
auto& self = static_cast<const Derived&>(*this);
if(self.transpose())
{
result = implicit_multi_op(self.parse(opd, parser, info, parser.to_nchw(args)));
std::transform(result.begin(), result.end(), result.begin(), [&](auto ins) {
return parser.to_nhwc(ins);
});
}
else
{
// if(self.transpose())
// {
// result = implicit_multi_op(self.parse(opd, parser, info, parser.to_nchw(args)));
// std::transform(result.begin(), result.end(), result.begin(), [&](auto ins) {
// return parser.to_nhwc(ins);
// });
// }
// else
// {
result = implicit_multi_op(self.parse(opd, parser, info, args));
}
// }
return result;
}
};
......
......@@ -47,7 +47,7 @@ struct parse_conv : op_parser<parse_conv>
{
std::vector<size_t> stride;
copy(info.attributes.at("strides").list().i(), std::back_inserter(stride));
parser.reorder_data(stride);
// parser.reorder_data(stride);
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
......@@ -59,7 +59,7 @@ struct parse_conv : op_parser<parse_conv>
{
std::vector<size_t> dilation;
copy(info.attributes.at("dilations").list().i(), std::back_inserter(dilation));
parser.reorder_data(dilation);
// parser.reorder_data(dilation);
if(dilation.size() != 4)
{
MIGRAPHX_THROW("dilation should have 4 values");
......
......@@ -60,8 +60,10 @@ struct parse_pack : op_parser<parse_pack>
[&](instruction_ref arg) {
return info.add_instruction(make_op("unsqueeze", {{"axes", {axis}}}), arg);
});
return parser.to_nhwc(
info.add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args));
return
info.add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args);
// return parser.to_nhwc(
// info.add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args));
}
};
......
......@@ -71,7 +71,10 @@ instruction_ref tf_parser::to_nchw(instruction_ref ins) const
instruction_ref tf_parser::to_kcxy(instruction_ref ins) const
{
return mm->add_instruction(make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), ins);
// x y c k
// k c x y
// k y x c
return mm->add_instruction(make_op("transpose", {{"permutation", {3, 1, 0, 2}}}), ins);
}
std::vector<instruction_ref> tf_parser::to_nchw(const std::vector<instruction_ref>& args) const
......@@ -282,17 +285,19 @@ void tf_parser::parse_graph(const tensorflow::GraphDef& graph)
}
else
{
if(is_nhwc and dims.size() >= 4)
{
this->reorder_data(dims);
}
// if(is_nhwc and dims.size() >= 4)
// {
// this->reorder_data(dims);
// }
std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) {
return static_cast<int>(dim) <= 0 ? batch_size : dim;
});
}
shape s = shape{shape_type, dims};
instructions[name] = to_nhwc(mm->add_parameter(name, s));
instructions[name] = mm->add_parameter(name, s);
// instructions[name] = to_nhwc(mm->add_parameter(name, s));
}
for(auto&& p : nodes)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment