Commit 877b8c44 authored by Khalique's avatar Khalique
Browse files

formatting

parent 4f5024b7
......@@ -83,7 +83,8 @@ struct convolution
// static_cast<std::size_t>(std::ceil(
// static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
// static_cast<std::size_t>(std::ceil(
// static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
// static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) /
// stride[1]))}};
// }
// else
// {
......
......@@ -56,13 +56,13 @@ struct pooling
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
floor_divide<std::ptrdiff_t>(
input.lens()[2] + 2 * padding[0] - lengths[0], stride[0]) +
floor_divide<std::ptrdiff_t>(input.lens()[2] + 2 * padding[0] - lengths[0],
stride[0]) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
floor_divide<std::ptrdiff_t>(
input.lens()[3] + 2 * padding[1] - lengths[1], stride[1]) +
floor_divide<std::ptrdiff_t>(input.lens()[3] + 2 * padding[1] - lengths[1],
stride[1]) +
1)),
}};
// }
......@@ -83,10 +83,12 @@ struct pooling
// input.lens()[1],
// std::size_t(std::max<std::ptrdiff_t>(
// 1,
// floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) + 1)),
// floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) +
// 1)),
// std::size_t(std::max<std::ptrdiff_t>(
// 1,
// floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) + 1)),
// floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) +
// 1)),
// }};
// }
// else
......
......@@ -15,16 +15,14 @@ argument
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
{
std::size_t nelements = arg1.get_shape().elements();
if(float_equal(value,std::numeric_limits<float>::lowest()))
if(float_equal(value, std::numeric_limits<float>::lowest()))
{
visit_all(result)([&](auto output) {
auto* outptr = device_cast(output.data());
auto val = device_cast(std::numeric_limits<typename
decltype(output)::value_type>::lowest());
auto val =
device_cast(std::numeric_limits<typename decltype(output)::value_type>::lowest());
gs_launch(stream, nelements)([=](auto i) {
outptr[i] = val;
});
gs_launch(stream, nelements)([=](auto i) { outptr[i] = val; });
});
}
......@@ -32,9 +30,7 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector
{
visit_all(result)([&](auto output) {
auto* outptr = device_cast(output.data());
gs_launch(stream, nelements)([=](auto i) {
outptr[i] = value;
});
gs_launch(stream, nelements)([=](auto i) { outptr[i] = value; });
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment