Commit 877b8c44 authored by Khalique's avatar Khalique
Browse files

formatting

parent 4f5024b7
......@@ -46,23 +46,23 @@ struct convolution
auto t = input.type();
// if(padding_mode == default_)
// {
return {t,
{
input.lens()[0],
weights.lens()[0],
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
2 * padding[0]) /
stride[0] +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
2 * padding[1]) /
stride[1] +
1)),
}};
return {t,
{
input.lens()[0],
weights.lens()[0],
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
2 * padding[0]) /
stride[0] +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
2 * padding[1]) /
stride[1] +
1)),
}};
// }
// else if(padding_mode == same)
// {
......@@ -83,7 +83,8 @@ struct convolution
// static_cast<std::size_t>(std::ceil(
// static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
// static_cast<std::size_t>(std::ceil(
// static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
// static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) /
// stride[1]))}};
// }
// else
// {
......
......@@ -50,21 +50,21 @@ struct pooling
// if(padding_mode == default_)
// {
return {t,
{
input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
floor_divide<std::ptrdiff_t>(
input.lens()[2] + 2 * padding[0] - lengths[0], stride[0]) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
floor_divide<std::ptrdiff_t>(
input.lens()[3] + 2 * padding[1] - lengths[1], stride[1]) +
1)),
}};
return {t,
{
input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
floor_divide<std::ptrdiff_t>(input.lens()[2] + 2 * padding[0] - lengths[0],
stride[0]) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
floor_divide<std::ptrdiff_t>(input.lens()[3] + 2 * padding[1] - lengths[1],
stride[1]) +
1)),
}};
// }
// else if(padding_mode == same)
// {
......@@ -83,10 +83,12 @@ struct pooling
// input.lens()[1],
// std::size_t(std::max<std::ptrdiff_t>(
// 1,
// floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) + 1)),
// floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) +
// 1)),
// std::size_t(std::max<std::ptrdiff_t>(
// 1,
// floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) + 1)),
// floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) +
// 1)),
// }};
// }
// else
......
......@@ -15,27 +15,23 @@ argument
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
{
std::size_t nelements = arg1.get_shape().elements();
if(float_equal(value,std::numeric_limits<float>::lowest()))
if(float_equal(value, std::numeric_limits<float>::lowest()))
{
visit_all(result)([&](auto output) {
auto* outptr = device_cast(output.data());
auto val = device_cast(std::numeric_limits<typename
decltype(output)::value_type>::lowest());
gs_launch(stream, nelements)([=](auto i) {
outptr[i] = val;
});
});
auto* outptr = device_cast(output.data());
auto val =
device_cast(std::numeric_limits<typename decltype(output)::value_type>::lowest());
gs_launch(stream, nelements)([=](auto i) { outptr[i] = val; });
});
}
else
{
visit_all(result)([&](auto output) {
auto* outptr = device_cast(output.data());
gs_launch(stream, nelements)([=](auto i) {
outptr[i] = value;
});
});
auto* outptr = device_cast(output.data());
gs_launch(stream, nelements)([=](auto i) { outptr[i] = value; });
});
}
// nary(stream, result)([=] { return value; });
......
......@@ -424,7 +424,7 @@ struct tf_parser
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
op.padding_mode = op::padding_mode_t::same;
std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment