Commit f2ed2b3b authored by Paul's avatar Paul
Browse files

Format

parent 692ce4b0
...@@ -34,11 +34,11 @@ ...@@ -34,11 +34,11 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template<class Output, class T, class Padding, class Stride> template <class Output, class T, class Padding, class Stride>
void convolution(Output output, T input, T weights, Padding padding, Stride stride, int group) void convolution(Output output, T input, T weights, Padding padding, Stride stride, int group)
{ {
auto output_shape = output.get_shape(); auto output_shape = output.get_shape();
auto in_lens = input.get_shape().lens(); auto in_lens = input.get_shape().lens();
auto wei_lens = weights.get_shape().lens(); auto wei_lens = weights.get_shape().lens();
auto wei_n = wei_lens[0]; auto wei_n = wei_lens[0];
...@@ -68,22 +68,21 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri ...@@ -68,22 +68,21 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri
std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end()); std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
idx[1] = in_ch; idx[1] = in_ch;
std::transform(idx_win.begin() + 1, std::transform(idx_win.begin() + 1,
idx_win.end(), idx_win.end(),
win_start.begin(), win_start.begin(),
idx.begin() + 2, idx.begin() + 2,
[](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; }); [](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; });
std::vector<std::ptrdiff_t> idx_wei(idx_o.size()); std::vector<std::ptrdiff_t> idx_wei(idx_o.size());
idx_wei[0] = w; idx_wei[0] = w;
std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1); std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1);
if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
std::equal(idx.begin(), std::equal(idx.begin(),
idx.end(), idx.end(),
in_lens.begin(), in_lens.begin(),
in_lens.end(), in_lens.end(),
std::less<std::ptrdiff_t>{})) std::less<std::ptrdiff_t>{}))
{ {
acc += acc += input(idx.begin(), idx.end()) * weights(idx_wei.begin(), idx_wei.end());
input(idx.begin(), idx.end()) * weights(idx_wei.begin(), idx_wei.end());
} }
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment