Commit 877b8c44 authored by Khalique's avatar Khalique
Browse files

formatting

parent 4f5024b7
...@@ -46,23 +46,23 @@ struct convolution ...@@ -46,23 +46,23 @@ struct convolution
auto t = input.type(); auto t = input.type();
// if(padding_mode == default_) // if(padding_mode == default_)
// { // {
return {t, return {t,
{ {
input.lens()[0], input.lens()[0],
weights.lens()[0], weights.lens()[0],
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) + (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
2 * padding[0]) / 2 * padding[0]) /
stride[0] + stride[0] +
1)), 1)),
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
(input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) + (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
2 * padding[1]) / 2 * padding[1]) /
stride[1] + stride[1] +
1)), 1)),
}}; }};
// } // }
// else if(padding_mode == same) // else if(padding_mode == same)
// { // {
...@@ -83,7 +83,8 @@ struct convolution ...@@ -83,7 +83,8 @@ struct convolution
// static_cast<std::size_t>(std::ceil( // static_cast<std::size_t>(std::ceil(
// static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])), // static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
// static_cast<std::size_t>(std::ceil( // static_cast<std::size_t>(std::ceil(
// static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}}; // static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) /
// stride[1]))}};
// } // }
// else // else
// { // {
......
...@@ -50,21 +50,21 @@ struct pooling ...@@ -50,21 +50,21 @@ struct pooling
// if(padding_mode == default_) // if(padding_mode == default_)
// { // {
return {t, return {t,
{ {
input.lens()[0], input.lens()[0],
input.lens()[1], input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
floor_divide<std::ptrdiff_t>( floor_divide<std::ptrdiff_t>(input.lens()[2] + 2 * padding[0] - lengths[0],
input.lens()[2] + 2 * padding[0] - lengths[0], stride[0]) + stride[0]) +
1)), 1)),
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
floor_divide<std::ptrdiff_t>( floor_divide<std::ptrdiff_t>(input.lens()[3] + 2 * padding[1] - lengths[1],
input.lens()[3] + 2 * padding[1] - lengths[1], stride[1]) + stride[1]) +
1)), 1)),
}}; }};
// } // }
// else if(padding_mode == same) // else if(padding_mode == same)
// { // {
...@@ -83,10 +83,12 @@ struct pooling ...@@ -83,10 +83,12 @@ struct pooling
// input.lens()[1], // input.lens()[1],
// std::size_t(std::max<std::ptrdiff_t>( // std::size_t(std::max<std::ptrdiff_t>(
// 1, // 1,
// floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) + 1)), // floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) +
// 1)),
// std::size_t(std::max<std::ptrdiff_t>( // std::size_t(std::max<std::ptrdiff_t>(
// 1, // 1,
// floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) + 1)), // floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) +
// 1)),
// }}; // }};
// } // }
// else // else
......
...@@ -15,27 +15,23 @@ argument ...@@ -15,27 +15,23 @@ argument
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads) pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
{ {
std::size_t nelements = arg1.get_shape().elements(); std::size_t nelements = arg1.get_shape().elements();
if(float_equal(value,std::numeric_limits<float>::lowest())) if(float_equal(value, std::numeric_limits<float>::lowest()))
{ {
visit_all(result)([&](auto output) { visit_all(result)([&](auto output) {
auto* outptr = device_cast(output.data()); auto* outptr = device_cast(output.data());
auto val = device_cast(std::numeric_limits<typename auto val =
decltype(output)::value_type>::lowest()); device_cast(std::numeric_limits<typename decltype(output)::value_type>::lowest());
gs_launch(stream, nelements)([=](auto i) { gs_launch(stream, nelements)([=](auto i) { outptr[i] = val; });
outptr[i] = val; });
});
});
} }
else else
{ {
visit_all(result)([&](auto output) { visit_all(result)([&](auto output) {
auto* outptr = device_cast(output.data()); auto* outptr = device_cast(output.data());
gs_launch(stream, nelements)([=](auto i) { gs_launch(stream, nelements)([=](auto i) { outptr[i] = value; });
outptr[i] = value; });
});
});
} }
// nary(stream, result)([=] { return value; }); // nary(stream, result)([=] { return value; });
......
...@@ -424,7 +424,7 @@ struct tf_parser ...@@ -424,7 +424,7 @@ struct tf_parser
if(pad_mode.find("SAME") != std::string::npos) if(pad_mode.find("SAME") != std::string::npos)
{ {
op.padding_mode = op::padding_mode_t::same; op.padding_mode = op::padding_mode_t::same;
std::vector<size_t> weight_dims = weights->get_shape().lens(); std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2]; size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3]; size_t weight_w = weight_dims[3];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment