Commit fd491ac1 authored by Khalique's avatar Khalique
Browse files

bugs in device pad

parent 877b8c44
...@@ -17,21 +17,28 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector ...@@ -17,21 +17,28 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector
std::size_t nelements = arg1.get_shape().elements(); std::size_t nelements = arg1.get_shape().elements();
if(float_equal(value, std::numeric_limits<float>::lowest())) if(float_equal(value, std::numeric_limits<float>::lowest()))
{ {
visit_all(result)([&](auto output) { auto val =
auto* outptr = device_cast(output.data()); device_cast(std::numeric_limits<decltype(value)>::lowest());
auto val = nary(stream, result)([=] { return val; });
device_cast(std::numeric_limits<typename decltype(output)::value_type>::lowest()); // visit_all(result)([&](auto output) {
// auto* outptr = device_cast(output.data());
// auto val =
// device_cast(std::numeric_limits<typename decltype(output)::value_type>::lowest());
gs_launch(stream, nelements)([=](auto i) { outptr[i] = val; }); // gs_launch(stream, nelements)([=](auto i) { outptr[i] = val; });
}); // });
} }
else else
{ {
visit_all(result)([&](auto output) { // visit_all(result)([&](auto output) {
auto* outptr = device_cast(output.data()); // auto* outptr = device_cast(output.data());
gs_launch(stream, nelements)([=](auto i) { outptr[i] = value; }); // auto val =
}); // device_cast(value);
// gs_launch(stream, nelements)([=](auto i) { outptr[i] = val; });
// });'
auto val = device_cast(value);
nary(stream, result)([=] { return val; });
} }
// nary(stream, result)([=] { return value; }); // nary(stream, result)([=] { return value; });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment