Commit a625f7b4 authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Replace divisions for index calculation with multiplies (#380)

* Implement fast-div for index calculations

* Formatting

* Use fast_div for broadcasts

* Formatting

* Add remiander function

* Compute mult-index using lens instead of strides

* Formatting

* Simplify equation

* Formatting
parent 1398bcc1
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_FAST_DIV_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_FAST_DIV_HPP
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
constexpr const std::size_t fast_div_shift = 42;
MIGRAPHX_DEVICE_CONSTEXPR std::size_t encode_divisor(std::size_t divisor)
{
if(divisor == 0)
return 0;
auto p = std::size_t{1} << fast_div_shift;
return (p + divisor - 1) / divisor;
}
inline constexpr bool is_divisor_encodable(std::size_t i)
{
return i < std::size_t{1} << (fast_div_shift / 2);
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t fast_div(std::size_t dividend, std::size_t encoded_divisor)
{
return (dividend * encoded_divisor) >> fast_div_shift;
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t
remainder(std::size_t result, std::size_t dividend, std::size_t divisor)
{
return dividend - divisor * result;
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t
fast_mod(std::size_t dividend, std::size_t divisor, std::size_t encoded_divisor)
{
return remainder(fast_div(dividend, encoded_divisor), dividend, divisor);
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -40,6 +40,17 @@ auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, A
});
}
inline auto create_broadcast_index(std::size_t len, std::size_t stride)
{
auto next_stride = stride * len;
auto e_next_stride = encode_divisor(next_stride);
auto e_stride = encode_divisor(stride);
return [=](auto i) {
// ( i % next_stride) / stride
return fast_div(i, e_stride) - len * fast_div(i, e_next_stride);
};
}
template <class F, class... Arguments>
auto nary_nonstandard_packed_impl(hipStream_t stream,
F f,
......@@ -68,9 +79,9 @@ void nary_broadcast_vec_impl(
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024;
......@@ -93,7 +104,7 @@ void nary_broadcast_vec_impl(
// Process the data
for(size_t i = idx.global; i < nelements; i += nglobal)
{
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto bidx = broadcast_idx(i * vec_size);
auto b = bp[bidx];
auto out = output.data()[i];
for(std::size_t j = 0; j < vec_size; j++)
......@@ -117,9 +128,9 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
......@@ -137,7 +148,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
// Process the data
for(size_t i = idx.global; i < nelements; i += nglobal)
{
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto bidx = broadcast_idx(i);
auto b = buffer[bidx];
output.data()[i] = f(inputs.data()[i]..., b);
}
......@@ -160,9 +171,9 @@ void nary_double_broadcast_vec_impl(
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024;
......@@ -189,7 +200,7 @@ void nary_double_broadcast_vec_impl(
// Process the data
for(size_t i = idx.global; i < nelements; i += nglobal)
{
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto bidx = broadcast_idx(i * vec_size);
auto b1 = bp[bidx];
auto b2 = bp[bidx + bdim_len];
auto out = output.data()[i];
......@@ -218,9 +229,9 @@ void nary_double_broadcast_impl(
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
......@@ -243,7 +254,7 @@ void nary_double_broadcast_impl(
// Process the data
for(size_t i = idx.global; i < nelements; i += nglobal)
{
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto bidx = broadcast_idx(i);
auto b1 = buffer[bidx];
auto b2 = buffer[bidx + bdim_len];
output.data()[i] = f(inputs.data()[i]..., b2, b1);
......
......@@ -3,6 +3,7 @@
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SHAPE_HPP
#include <migraphx/gpu/device/array.hpp>
#include <migraphx/gpu/device/fast_div.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -15,6 +16,7 @@ struct hip_shape
using hip_index = hip_array<std::size_t, N>;
hip_array<std::size_t, N> lens = {};
hip_array<std::size_t, N> strides = {};
hip_array<std::size_t, N> divs = {};
bool standard = false;
__device__ __host__ hip_shape() = default;
......@@ -25,6 +27,8 @@ struct hip_shape
assert(s.strides().size() == N);
std::copy(s.lens().begin(), s.lens().end(), lens.begin());
std::copy(s.strides().begin(), s.strides().end(), strides.begin());
assert(std::all_of(s.lens().begin(), s.lens().end(), &is_divisor_encodable));
std::transform(s.lens().begin(), s.lens().end(), divs.begin(), &encode_divisor);
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t elements() const { return lens.product(); }
......@@ -66,10 +70,13 @@ struct hip_shape
{
hip_index result;
std::size_t tidx = idx;
for(std::size_t is = 0; is < result.size(); is++)
for(std::ptrdiff_t is = result.size() - 1; is >= 0; is--)
{
result[is] = tidx / strides[is];
tidx = tidx % strides[is];
// result[is] = tidx % lens[is];
// tidx = tdix / lens[is];
auto q = fast_div(tidx, divs[is]);
result[is] = remainder(q, tidx, lens[is]);
tidx = q;
}
return result;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment