"src/include/blockwise_4d_tensor_op.hpp" did not exist on "8a4b59785b4f5ba48468d53618ca270c5da599a7"
Commit a625f7b4 authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Replace divisions for index calculation with multiplies (#380)

* Implement fast-div for index calculations

* Formatting

* Use fast_div for broadcasts

* Formatting

* Add remiander function

* Compute mult-index using lens instead of strides

* Formatting

* Simplify equation

* Formatting
parent 1398bcc1
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_FAST_DIV_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_FAST_DIV_HPP
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
constexpr const std::size_t fast_div_shift = 42;
MIGRAPHX_DEVICE_CONSTEXPR std::size_t encode_divisor(std::size_t divisor)
{
if(divisor == 0)
return 0;
auto p = std::size_t{1} << fast_div_shift;
return (p + divisor - 1) / divisor;
}
inline constexpr bool is_divisor_encodable(std::size_t i)
{
return i < std::size_t{1} << (fast_div_shift / 2);
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t fast_div(std::size_t dividend, std::size_t encoded_divisor)
{
return (dividend * encoded_divisor) >> fast_div_shift;
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t
remainder(std::size_t result, std::size_t dividend, std::size_t divisor)
{
return dividend - divisor * result;
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t
fast_mod(std::size_t dividend, std::size_t divisor, std::size_t encoded_divisor)
{
return remainder(fast_div(dividend, encoded_divisor), dividend, divisor);
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -40,6 +40,17 @@ auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, A ...@@ -40,6 +40,17 @@ auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, A
}); });
} }
inline auto create_broadcast_index(std::size_t len, std::size_t stride)
{
auto next_stride = stride * len;
auto e_next_stride = encode_divisor(next_stride);
auto e_stride = encode_divisor(stride);
return [=](auto i) {
// ( i % next_stride) / stride
return fast_div(i, e_stride) - len * fast_div(i, e_next_stride);
};
}
template <class F, class... Arguments> template <class F, class... Arguments>
auto nary_nonstandard_packed_impl(hipStream_t stream, auto nary_nonstandard_packed_impl(hipStream_t stream,
F f, F f,
...@@ -68,9 +79,9 @@ void nary_broadcast_vec_impl( ...@@ -68,9 +79,9 @@ void nary_broadcast_vec_impl(
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) { std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0; return x != 0;
})); }));
auto bdim_len = output_shape.lens()[bdim]; auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const std::size_t vec_size = 4; const std::size_t vec_size = 4;
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
...@@ -93,7 +104,7 @@ void nary_broadcast_vec_impl( ...@@ -93,7 +104,7 @@ void nary_broadcast_vec_impl(
// Process the data // Process the data
for(size_t i = idx.global; i < nelements; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride; auto bidx = broadcast_idx(i * vec_size);
auto b = bp[bidx]; auto b = bp[bidx];
auto out = output.data()[i]; auto out = output.data()[i];
for(std::size_t j = 0; j < vec_size; j++) for(std::size_t j = 0; j < vec_size; j++)
...@@ -117,9 +128,9 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg ...@@ -117,9 +128,9 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) { std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0; return x != 0;
})); }));
auto bdim_len = output_shape.lens()[bdim]; auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal; const std::size_t nglobal = 256 * nlocal;
...@@ -137,7 +148,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg ...@@ -137,7 +148,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
// Process the data // Process the data
for(size_t i = idx.global; i < nelements; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
auto bidx = (i % bdim_next_stride) / bdim_stride; auto bidx = broadcast_idx(i);
auto b = buffer[bidx]; auto b = buffer[bidx];
output.data()[i] = f(inputs.data()[i]..., b); output.data()[i] = f(inputs.data()[i]..., b);
} }
...@@ -160,9 +171,9 @@ void nary_double_broadcast_vec_impl( ...@@ -160,9 +171,9 @@ void nary_double_broadcast_vec_impl(
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) { std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0; return x != 0;
})); }));
auto bdim_len = output_shape.lens()[bdim]; auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const std::size_t vec_size = 4; const std::size_t vec_size = 4;
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
...@@ -189,7 +200,7 @@ void nary_double_broadcast_vec_impl( ...@@ -189,7 +200,7 @@ void nary_double_broadcast_vec_impl(
// Process the data // Process the data
for(size_t i = idx.global; i < nelements; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride; auto bidx = broadcast_idx(i * vec_size);
auto b1 = bp[bidx]; auto b1 = bp[bidx];
auto b2 = bp[bidx + bdim_len]; auto b2 = bp[bidx + bdim_len];
auto out = output.data()[i]; auto out = output.data()[i];
...@@ -218,9 +229,9 @@ void nary_double_broadcast_impl( ...@@ -218,9 +229,9 @@ void nary_double_broadcast_impl(
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) { std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0; return x != 0;
})); }));
auto bdim_len = output_shape.lens()[bdim]; auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal; const std::size_t nglobal = 256 * nlocal;
...@@ -243,7 +254,7 @@ void nary_double_broadcast_impl( ...@@ -243,7 +254,7 @@ void nary_double_broadcast_impl(
// Process the data // Process the data
for(size_t i = idx.global; i < nelements; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
auto bidx = (i % bdim_next_stride) / bdim_stride; auto bidx = broadcast_idx(i);
auto b1 = buffer[bidx]; auto b1 = buffer[bidx];
auto b2 = buffer[bidx + bdim_len]; auto b2 = buffer[bidx + bdim_len];
output.data()[i] = f(inputs.data()[i]..., b2, b1); output.data()[i] = f(inputs.data()[i]..., b2, b1);
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SHAPE_HPP #define MIGRAPHX_GUARD_RTGLIB_DEVICE_SHAPE_HPP
#include <migraphx/gpu/device/array.hpp> #include <migraphx/gpu/device/array.hpp>
#include <migraphx/gpu/device/fast_div.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -15,6 +16,7 @@ struct hip_shape ...@@ -15,6 +16,7 @@ struct hip_shape
using hip_index = hip_array<std::size_t, N>; using hip_index = hip_array<std::size_t, N>;
hip_array<std::size_t, N> lens = {}; hip_array<std::size_t, N> lens = {};
hip_array<std::size_t, N> strides = {}; hip_array<std::size_t, N> strides = {};
hip_array<std::size_t, N> divs = {};
bool standard = false; bool standard = false;
__device__ __host__ hip_shape() = default; __device__ __host__ hip_shape() = default;
...@@ -25,6 +27,8 @@ struct hip_shape ...@@ -25,6 +27,8 @@ struct hip_shape
assert(s.strides().size() == N); assert(s.strides().size() == N);
std::copy(s.lens().begin(), s.lens().end(), lens.begin()); std::copy(s.lens().begin(), s.lens().end(), lens.begin());
std::copy(s.strides().begin(), s.strides().end(), strides.begin()); std::copy(s.strides().begin(), s.strides().end(), strides.begin());
assert(std::all_of(s.lens().begin(), s.lens().end(), &is_divisor_encodable));
std::transform(s.lens().begin(), s.lens().end(), divs.begin(), &encode_divisor);
} }
MIGRAPHX_DEVICE_CONSTEXPR std::size_t elements() const { return lens.product(); } MIGRAPHX_DEVICE_CONSTEXPR std::size_t elements() const { return lens.product(); }
...@@ -66,10 +70,13 @@ struct hip_shape ...@@ -66,10 +70,13 @@ struct hip_shape
{ {
hip_index result; hip_index result;
std::size_t tidx = idx; std::size_t tidx = idx;
for(std::size_t is = 0; is < result.size(); is++) for(std::ptrdiff_t is = result.size() - 1; is >= 0; is--)
{ {
result[is] = tidx / strides[is]; // result[is] = tidx % lens[is];
tidx = tidx % strides[is]; // tidx = tdix / lens[is];
auto q = fast_div(tidx, divs[is]);
result[is] = remainder(q, tidx, lens[is]);
tidx = q;
} }
return result; return result;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment