Commit 877b8c44 authored by Khalique's avatar Khalique
Browse files

formatting

parent 4f5024b7
...@@ -83,7 +83,8 @@ struct convolution ...@@ -83,7 +83,8 @@ struct convolution
// static_cast<std::size_t>(std::ceil( // static_cast<std::size_t>(std::ceil(
// static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])), // static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
// static_cast<std::size_t>(std::ceil( // static_cast<std::size_t>(std::ceil(
// static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}}; // static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) /
// stride[1]))}};
// } // }
// else // else
// { // {
......
...@@ -56,13 +56,13 @@ struct pooling ...@@ -56,13 +56,13 @@ struct pooling
input.lens()[1], input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
floor_divide<std::ptrdiff_t>( floor_divide<std::ptrdiff_t>(input.lens()[2] + 2 * padding[0] - lengths[0],
input.lens()[2] + 2 * padding[0] - lengths[0], stride[0]) + stride[0]) +
1)), 1)),
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
floor_divide<std::ptrdiff_t>( floor_divide<std::ptrdiff_t>(input.lens()[3] + 2 * padding[1] - lengths[1],
input.lens()[3] + 2 * padding[1] - lengths[1], stride[1]) + stride[1]) +
1)), 1)),
}}; }};
// } // }
...@@ -83,10 +83,12 @@ struct pooling ...@@ -83,10 +83,12 @@ struct pooling
// input.lens()[1], // input.lens()[1],
// std::size_t(std::max<std::ptrdiff_t>( // std::size_t(std::max<std::ptrdiff_t>(
// 1, // 1,
// floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) + 1)), // floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) +
// 1)),
// std::size_t(std::max<std::ptrdiff_t>( // std::size_t(std::max<std::ptrdiff_t>(
// 1, // 1,
// floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) + 1)), // floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) +
// 1)),
// }}; // }};
// } // }
// else // else
......
...@@ -15,16 +15,14 @@ argument ...@@ -15,16 +15,14 @@ argument
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads) pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
{ {
std::size_t nelements = arg1.get_shape().elements(); std::size_t nelements = arg1.get_shape().elements();
if(float_equal(value,std::numeric_limits<float>::lowest())) if(float_equal(value, std::numeric_limits<float>::lowest()))
{ {
visit_all(result)([&](auto output) { visit_all(result)([&](auto output) {
auto* outptr = device_cast(output.data()); auto* outptr = device_cast(output.data());
auto val = device_cast(std::numeric_limits<typename auto val =
decltype(output)::value_type>::lowest()); device_cast(std::numeric_limits<typename decltype(output)::value_type>::lowest());
gs_launch(stream, nelements)([=](auto i) { gs_launch(stream, nelements)([=](auto i) { outptr[i] = val; });
outptr[i] = val;
});
}); });
} }
...@@ -32,9 +30,7 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector ...@@ -32,9 +30,7 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector
{ {
visit_all(result)([&](auto output) { visit_all(result)([&](auto output) {
auto* outptr = device_cast(output.data()); auto* outptr = device_cast(output.data());
gs_launch(stream, nelements)([=](auto i) { gs_launch(stream, nelements)([=](auto i) { outptr[i] = value; });
outptr[i] = value;
});
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment