"git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "2e5794c7b60854856e070492ec24ec93a116edab"
Commit b171e9ad authored by Paul
Browse files

Refactor pad

parent 1599d553
...@@ -13,6 +13,7 @@ template <class T, std::size_t N> ...@@ -13,6 +13,7 @@ template <class T, std::size_t N>
struct hip_tensor_view struct hip_tensor_view
{ {
using value_type = T; using value_type = T;
using hip_index = typename hip_shape<N>::hip_index;
__device__ __host__ hip_tensor_view() = default; __device__ __host__ hip_tensor_view() = default;
__host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {} __host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {}
__host__ hip_tensor_view(T* x, const shape& ss) : d(x), s(ss) {} __host__ hip_tensor_view(T* x, const shape& ss) : d(x), s(ss) {}
......
...@@ -15,34 +15,26 @@ argument ...@@ -15,34 +15,26 @@ argument
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads) pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
{ {
std::size_t nelements = arg1.get_shape().elements(); std::size_t nelements = arg1.get_shape().elements();
visit_all(result)([&](auto output) { hip_visit_all(result, arg1)([&](auto output, auto input) {
auto* outptr = device_cast(output.data());
using type = typename decltype(output)::value_type; using type = typename decltype(output)::value_type;
device_type<type> device_val = value; using hip_index = typename decltype(output)::hip_index;
type device_val = value;
if(float_equal(value, std::numeric_limits<float>::lowest())) if(float_equal(value, std::numeric_limits<float>::lowest()))
{ {
device_val = device_cast(std::numeric_limits<type>::lowest()); device_val = device_cast(std::numeric_limits<type>::lowest());
} }
gs_launch(stream, result.get_shape().elements())([=](auto i) { outptr[i] = device_val; }); gs_launch(stream, result.get_shape().elements())([=](auto i) { output.data()[i] = device_val; });
});
visit_all(result, arg1)([&](auto output, auto input) { hip_index offsets;
visit_tensor_size(result.get_shape().lens().size(), [&](auto ndim) { std::copy(pads.begin(), pads.begin() + offsets.size(), offsets.begin());
std::size_t offsets[ndim]; gs_launch(stream, nelements)([=](auto i) {
std::copy(pads.begin(), pads.begin() + ndim, offsets); auto idx = input.get_shape().multi(i);
auto* outptr = output.data(); for(std::size_t j = 0; j < offsets.size(); j++)
const auto* inptr = input.data(); {
hip_tensor_descriptor<ndim> desc_input(input.get_shape()); idx[j] += offsets[j];
hip_tensor_descriptor<ndim> desc_output(output.get_shape()); }
gs_launch(stream, nelements)([=](auto i) { output[idx] = input.data()[i];
auto idx = desc_input.multi(i); });
for(std::size_t j = 0; j < ndim; j++)
{
idx[j] += offsets[j];
}
outptr[desc_output.linear(idx)] = inptr[i];
});
});
}); });
return result; return result;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment