Commit 3e0496fb authored by Paul Fultz II, committed by mvermeulen

Use multi_index for reductions (#400)

* Add functions to compute multi-indices for local strides as well

* Formatting

* Use same multi-index path for block_reduce

* Formatting

* Use multi-index calc in reduce

* Formatting

* Fix warning

* Fix compiler warning

* Disable some tidy checks
parent 78c83426
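Editorial note: the diff below routes both element-wise and reduction kernels through a carry-based multi-index stride loop. As a rough host-side model of that technique, here is a minimal sketch; the names multi, carry_add, and for_stride are illustrative stand-ins, not the actual MIGraphX types.

#include <array>
#include <cstddef>
#include <iostream>

template <std::size_t N>
using multi = std::array<std::size_t, N>; // stand-in for hip_array/hip_index

// Mixed-radix add: add the stride element-wise, then propagate overflow from
// the fastest-moving (last) dimension toward the slowest-moving (first) one.
// Any leftover overflow ends up in element 0, which terminates the loop below.
template <std::size_t N>
multi<N> carry_add(multi<N> i, const multi<N>& stride, const multi<N>& lens)
{
    for (std::size_t d = 0; d < N; ++d)
        i[d] += stride[d];
    for (std::size_t d = N - 1; d > 0; --d)
    {
        if (i[d] >= lens[d])
        {
            i[d - 1] += i[d] / lens[d];
            i[d] %= lens[d];
        }
    }
    return i;
}

// Each "thread" starts at its own multi-index id and walks the space `lens`
// in steps of a multi-index stride, so the same loop covers standard and
// permuted layouts without converting back to a flat index.
template <std::size_t N, class F>
void for_stride(multi<N> id, const multi<N>& stride, const multi<N>& lens, F f)
{
    for (multi<N> i = id; i[0] < lens[0]; i = carry_add(i, stride, lens))
        f(i);
}

int main()
{
    multi<2> lens{{2, 3}};
    // Two threads with ids (0,0) and (0,1) and stride (0,2) cover all six
    // elements of the 2x3 space with no overlap.
    for (std::size_t t = 0; t < 2; ++t)
        for_stride<2>({{0, t}}, {{0, 2}}, lens, [&](const multi<2>& i) {
            std::cout << "thread " << t << " -> (" << i[0] << "," << i[1] << ")\n";
        });
}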
......@@ -72,6 +72,7 @@ rocm_enable_clang_tidy(
-google-runtime-references
-hicpp-braces-around-statements
-hicpp-explicit-conversions
-hicpp-member-init
-hicpp-no-array-decay
-hicpp-special-member-functions
-hicpp-uppercase-literal-suffix
......@@ -80,9 +81,11 @@ rocm_enable_clang_tidy(
-llvm-header-guard
-llvm-include-order
-misc-macro-parentheses
-modernize-use-override
-modernize-concat-nested-namespaces
-modernize-pass-by-value
-modernize-use-default-member-init
-modernize-use-nodiscard
-modernize-use-override
-modernize-use-trailing-return-type
-modernize-use-transparent-functors
-performance-type-promotion-in-math-fn
......
......@@ -207,6 +207,15 @@ auto always(T x)
return always_f<T>{x};
}
struct id
{
template <class T>
constexpr T operator()(T&& x) const
{
return static_cast<T&&>(x);
}
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
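Note: the new id functor above is a perfect-forwarding identity. In this commit it serves as a probe argument for return-type deduction (deduce_for_stride in multi_index.hpp below uses decltype(fs(id{}))). A minimal standalone illustration, with id re-declared so the snippet compiles on its own:

#include <type_traits>

struct id
{
    template <class T>
    constexpr T operator()(T&& x) const
    {
        return static_cast<T&&>(x);
    }
};

static_assert(id{}(3) == 3, "id forwards its argument unchanged");

// Passing id{} into a higher-order function lets decltype() ask "what type
// would this function hand to its callback?" without running anything.
template <class ForStride>
auto probe_callback_argument(ForStride fs) -> decltype(fs(id{}));

// Example: fs is a curried loop that calls its callback with an int.
const auto example_fs = [](auto f) { return f(42); };
using deduced = decltype(probe_callback_argument(example_fs));
static_assert(std::is_same<deduced, int>::value, "callback receives an int");

int main() {}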
......@@ -64,6 +64,7 @@ add_library(migraphx_device
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_set_soversion(migraphx_device ${PROJECT_VERSION})
rocm_clang_tidy_check(migraphx_device)
target_compile_options(migraphx_device PRIVATE -std=c++17)
target_link_libraries(migraphx_device migraphx hip::device -Wno-invalid-command-line-argument -amdgpu-target=gfx803 -amdgpu-target=gfx900 -amdgpu-target=gfx906)
target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
......
......@@ -61,13 +61,13 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
}
template <class F>
__host__ __device__ auto gs_invoke(F&& f, index_int i, index idx) -> decltype(f(i, idx))
MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index idx) -> decltype(f(i, idx))
{
return f(i, idx);
}
template <class F>
__host__ __device__ auto gs_invoke(F&& f, index_int i, index) -> decltype(f(i))
MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype(f(i))
{
return f(i);
}
......
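Note: the gs_invoke pair above (now MIGRAPHX_DEVICE_CONSTEXPR) dispatches on the arity of the kernel body: the trailing decltype removes whichever overload's call expression does not compile, so callers may pass either f(i) or f(i, idx). A host-only model of that dispatch; launch_index and the demo namespace are hypothetical stand-ins, not the real index type.

#include <cstdint>
#include <iostream>

namespace demo {

using index_int = std::uint32_t;
struct launch_index { index_int global; index_int local; }; // stand-in only

// Selected when f accepts (i, idx); removed by SFINAE otherwise.
template <class F>
auto gs_invoke(F&& f, index_int i, launch_index idx) -> decltype(f(i, idx))
{
    return f(i, idx);
}

// Selected when f accepts only (i).
template <class F>
auto gs_invoke(F&& f, index_int i, launch_index) -> decltype(f(i))
{
    return f(i);
}

} // namespace demo

int main()
{
    demo::launch_index idx{7, 3};
    demo::gs_invoke([](demo::index_int i) { std::cout << "arity 1: i=" << i << "\n"; }, 5, idx);
    demo::gs_invoke([](demo::index_int i, demo::launch_index x) {
                        std::cout << "arity 2: i=" << i << " local=" << x.local << "\n";
                    },
                    5, idx);
}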
......@@ -4,6 +4,7 @@
#include <migraphx/config.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/shape.hpp>
#include <migraphx/functional.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -17,16 +18,26 @@ struct multi_index
hip_index id{};
hip_index stride{};
template <class F>
MIGRAPHX_DEVICE_CONSTEXPR void for_stride(hip_index n, F f) const
MIGRAPHX_DEVICE_CONSTEXPR auto for_stride(hip_index n) const
{
for(hip_index i = id; i < n; i = n.carry(i + stride))
{
f(i);
}
// f should return void, but this helps with type deduction
return [=](auto f) -> decltype(f(hip_index{})) {
for(hip_index i = id; i < n; i = n.carry(i + stride))
{
f(i);
}
};
}
};
template <class ForStride>
auto deduce_for_stride(ForStride fs) -> decltype(fs(id{}));
MIGRAPHX_DEVICE_CONSTEXPR multi_index<1> make_multi_index(index_int i, index_int n)
{
return {{i}, {n}};
}
template <index_int N>
MIGRAPHX_DEVICE_CONSTEXPR multi_index<N>
make_multi_index(const hip_shape<N>& s, index_int i, index_int n)
......@@ -42,30 +53,83 @@ make_multi_index(const hip_shape<N>& s, index_int i, const hip_array<index_int,
}
template <index_int N>
inline auto mi_launch(hipStream_t stream, const hip_shape<N>& s, index_int local = 1024)
inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal)
{
assert(s.standard);
assert(s.elements() > 0);
index_int n = s.elements();
index_int groups = (n + local - 1) / local;
index_int nglobal = std::min<index_int>(128, groups) * local;
index_int groups = (n + nlocal - 1) / nlocal;
index_int nglobal = std::min<index_int>(128, groups) * nlocal;
assert(groups > 0);
assert(nglobal > 0);
auto nglobal_multi = s.multi(nglobal);
// Skip checking this, since this will cause metadata to not be generated
// for some unknown reason.
//
// assert(std::any_of(nglobal_multi.begin(), nglobal_multi.end(), [](auto x){return x>0;}));
return nglobal_multi;
}
template <index_int N>
inline auto mi_nlocal(const hip_shape<N>& s, index_int local)
{
assert(s.standard);
assert(s.elements() > 0);
auto nlocal_multi = s.multi(local);
// Skip checking this, since this will cause metadata to not be generated
// for some unknown reason.
//
// assert(std::any_of(nlocal_multi.begin(), nlocal_multi.end(), [](auto x){return x>0;}));
return nlocal_multi;
}
template <index_int N>
inline auto mi_launch(hipStream_t stream, const hip_shape<N>& global, index_int nlocal = 1024)
{
auto nglobal_multi = mi_nglobal(global, nlocal);
auto nglobal = global.index(nglobal_multi);
return [=](auto f) {
launch(stream, nglobal, local)([=](auto idx) {
auto midx = make_multi_index(s, idx.global, nglobal_multi);
midx.for_stride(s.lens, [&](auto i) { f(i); });
launch(stream, nglobal, nlocal)([=](auto idx) {
auto midx = make_multi_index(global, idx.global, nglobal_multi);
f(idx, midx.for_stride(global.lens));
});
};
}
template <index_int N>
inline auto mi_launch(hipStream_t stream,
const hip_shape<N>& global,
const hip_shape<N>& local,
index_int nlocal = 1024)
{
auto nglobal_multi = mi_nglobal(global, 1);
auto nglobal = global.index(nglobal_multi);
auto nlocal_multi = mi_nlocal(local, nlocal);
return [=](auto f) {
launch(stream, nglobal * nlocal, nlocal)([=](auto idx) {
// TODO: Use fast div for nlocal
auto midx = make_multi_index(global, idx.global / nlocal, nglobal_multi);
auto lidx = make_multi_index(local, idx.local, nlocal_multi);
f(idx, midx.for_stride(global.lens), lidx.for_stride(local.lens));
});
};
}
template <index_int N>
inline auto mi_gs_launch(hipStream_t stream, const hip_shape<N>& global, index_int nlocal = 1024)
{
return [=](auto f) {
mi_launch(stream, global, nlocal)([=](auto, auto g) { g([&](auto i) { f(i); }); });
};
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
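Note: the central change in multi_index.hpp is that for_stride(n) no longer runs the loop directly; it returns a callable that takes the body, so block_reduce can deduce its accumulator type from the loop (via deduce_for_stride and id), and mi_launch can hand kernels pre-bound global/local loop objects. Below is a host-side model of that currying, serialized and with illustrative names (strided_range, launch_model) rather than the real MIGraphX API.

#include <cstddef>
#include <iostream>

struct strided_range // 1-D stand-in for multi_index
{
    std::size_t id;
    std::size_t stride;

    // Curried loop: returns the iteration itself; callers apply it to a body
    // later, which is what makes decltype-based deduction possible.
    auto for_stride(std::size_t n) const
    {
        auto start = id;
        auto step  = stride;
        return [=](auto f) {
            for (std::size_t i = start; i < n; i += step)
                f(i);
        };
    }
};

// A launcher in this style passes pre-bound loops to the body instead of
// looping itself, mirroring mi_launch's (idx, global, local) callback shape.
template <class F>
void launch_model(std::size_t nglobal, std::size_t nlocal, F body)
{
    for (std::size_t t = 0; t < nglobal; ++t) // serialized "threads"
    {
        auto global = strided_range{t / nlocal, nglobal / nlocal}.for_stride(8);
        auto local  = strided_range{t % nlocal, nlocal}.for_stride(4);
        body(t, global, local);
    }
}

int main()
{
    launch_model(4, 2, [](std::size_t t, auto global, auto local) {
        global([&](std::size_t g) {
            local([&](std::size_t l) {
                std::cout << "thread " << t << " covers (" << g << "," << l << ")\n";
            });
        });
    });
}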
......@@ -25,7 +25,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY);
std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl;
template <class... Ts>
auto pack(Ts... xs) __device__
constexpr auto pack(Ts... xs)
{
return [=](auto f) { return f(xs...); };
}
......@@ -36,7 +36,7 @@ auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, A
MIGRAPHX_TRACE_NARY_FUNCTION
shape s{result.get_shape().type(), result.get_shape().lens()};
hip_visit_all(s, result, args...)([&](auto standard_shape, auto output, auto... inputs) {
mi_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); });
mi_gs_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); });
});
}
......@@ -61,10 +61,11 @@ auto nary_nonstandard_packed_impl(hipStream_t stream,
auto arg_shape = make_array(args...).front().get_shape();
auto perm = find_permutation(arg_shape);
auto s = reorder_shape(arg_shape, perm);
hip_visit_all(s, result.reshape(reorder_shape(result.get_shape(), perm)), args.reshape(s)...)(
[&](auto standard_shape, auto output, auto... inputs) {
mi_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); });
});
hip_visit_all(s,
result.reshape(reorder_shape(result.get_shape(), perm)),
args.reshape(s)...)([&](auto standard_shape, auto output, auto... inputs) {
mi_gs_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); });
});
}
template <class F, class... Arguments>
......
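Aside: the pack helper made constexpr above captures an argument pack and applies it to a continuation later, which lets the nary kernels carry a variable number of input views. A standalone re-statement of that idiom (same shape as the diff, re-declared here so it compiles on its own; returning a lambda from a constexpr function requires C++17, which the device library now enables):

#include <iostream>

template <class... Ts>
constexpr auto pack(Ts... xs)
{
    return [=](auto f) { return f(xs...); };
}

int main()
{
    auto args = pack(2, 3.5, 'x');
    // Unpack when convenient, e.g. forwarding every captured value to a body.
    double sum = args([](int a, double b, char c) { return a + b + c; });
    std::cout << sum << "\n"; // 2 + 3.5 + 120 ('x') = 125.5
}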
......@@ -4,6 +4,7 @@
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/visit.hpp>
#include <migraphx/gpu/device/multi_index.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -176,13 +177,18 @@ __device__ inline void dpp_reduce(float& x, sum)
#endif
}
template <index_int N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
template <index_int N,
class Op,
class T,
class ForStride,
class F,
MIGRAPHX_REQUIRES(not std::is_integral<ForStride>{})>
__device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f)
{
using type = decltype(f(idx.local));
using type = decltype(f(deduce_for_stride(fs)));
MIGRAPHX_DEVICE_SHARED type buffer[N / 64];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
fs([&](auto i) { x = op(x, f(i)); });
dpp_reduce(x, op);
const auto ldsidx = idx.local / 64;
......@@ -199,6 +205,18 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
}
return y;
}
template <index_int N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
{
auto midx = make_multi_index(idx.local, idx.nlocal());
// Workaround hcc, create a local array
auto fs = midx.id;
fs[0] = n;
return block_reduce<N>(
idx, op, init, midx.for_stride(fs), [&](auto mi) __device__ { return f(mi[0]); });
}
#endif
constexpr index_int compute_block_size(index_int n, index_int max_block_size)
{
......@@ -219,21 +237,21 @@ void reduce_multi_impl(hipStream_t stream,
const shape& reduce_slice)
{
hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) {
auto nelements = result.get_shape().elements();
auto relements = reduce_slice.elements();
const index_int max_block_size = 256;
const index_int block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size;
auto base_idx = output.get_shape().multi(out_idx);
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
auto reduce_idx = reduce_shape.multi(j);
return read_input(input[reduce_idx + base_idx]);
mi_launch(stream, output.get_shape(), reduce_shape, block_size)(
[=](auto idx, auto global, auto local) __device__ {
global([&](auto i) __device__ {
auto r =
block_reduce<max_block_size>(idx, op, init, local, [&](auto j) __device__ {
return read_input(input[i + j]);
});
if(idx.local == 0)
output[i] = read_output(r);
});
});
if(idx.local == 0)
output.data()[out_idx] = read_output(r);
});
});
}
......
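Note: block_reduce now comes in two layers. The new overload takes the curried loop (ForStride) and deduces its accumulator type from it, while the old integral-count overload wraps n into a one-element multi-index for_stride (the hcc workaround noted in the diff) so both paths share the DPP/LDS reduction. Below is a rough host-side model of that structure, with no HIP intrinsics and hypothetical names (block_reduce_model, make_fs); it uses a per-lane range factory instead of a per-thread loop object, but the layering is the same.

#include <cstddef>
#include <functional>
#include <iostream>
#include <type_traits>
#include <vector>

// Callable-range overload: every lane accumulates over the strided range it
// is handed, then the per-lane partials are combined (dpp_reduce + LDS on the
// GPU; a plain loop here). Assumes init is the identity of op.
template <class Op, class T, class MakeRange, class F,
          std::enable_if_t<!std::is_integral<MakeRange>::value, int> = 0>
T block_reduce_model(std::size_t nlanes, Op op, T init, MakeRange make_fs, F f)
{
    std::vector<T> partial(nlanes, init);
    for (std::size_t lane = 0; lane < nlanes; ++lane) // lanes run in parallel on a GPU
        make_fs(lane)([&](std::size_t i) { partial[lane] = op(partial[lane], f(i)); });
    T result = init;
    for (T p : partial)
        result = op(result, p);
    return result;
}

// Integral-count overload, analogous to the wrapper added in the diff: build
// the strided range from (lane, nlanes, n) and reuse the callable path above.
template <class Op, class T, class F>
T block_reduce_model(std::size_t nlanes, Op op, T init, std::size_t n, F f)
{
    auto make_fs = [=](std::size_t lane) {
        return [=](auto body) {
            for (std::size_t i = lane; i < n; i += nlanes)
                body(i);
        };
    };
    return block_reduce_model(nlanes, op, init, make_fs, f);
}

int main()
{
    // Sum of 0..99 with 8 lanes: prints 4950.
    std::cout << block_reduce_model(8, std::plus<>{}, 0, std::size_t{100},
                                    [](std::size_t i) { return static_cast<int>(i); })
              << "\n";
}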
......@@ -10,7 +10,7 @@ namespace gpu {
namespace device {
template <class F>
void visit_tensor_size(index_int n, F f)
constexpr void visit_tensor_size(index_int n, F f)
{
switch(n)
{
......
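Note: visit_tensor_size (made constexpr above) is a runtime-rank to compile-time-rank dispatcher. Its switch body is collapsed in this diff, so the cases below are assumptions; the sketch only illustrates the pattern of handing the callback an integral constant it can use as a template argument.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <type_traits>

using index_int = std::uint32_t;

template <class F>
constexpr void visit_rank(index_int n, F f) // hypothetical stand-in
{
    switch (n)
    {
    case 1: f(std::integral_constant<index_int, 1>{}); break;
    case 2: f(std::integral_constant<index_int, 2>{}); break;
    case 3: f(std::integral_constant<index_int, 3>{}); break;
    case 4: f(std::integral_constant<index_int, 4>{}); break;
    default: throw std::runtime_error("unsupported rank");
    }
}

int main()
{
    visit_rank(3, [](auto rank) {
        // rank carries the dimension count as a compile-time constant, so the
        // caller could instantiate a fixed-rank shape type from it.
        std::cout << "dispatched to rank " << decltype(rank)::value << "\n";
    });
}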
......@@ -69,7 +69,7 @@ void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument
});
}
void sync_stream(hipStream_t stream) { hipStreamSynchronize(stream); }
void sync_stream(hipStream_t stream) { (void)hipStreamSynchronize(stream); }
} // namespace device
} // namespace gpu
......
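Note: the sync_stream change casts the hipStreamSynchronize result to void, presumably because the compiler warns about the discarded error code (the commit message mentions fixing a compiler warning); the cast states the discard is intentional. A tiny standalone illustration of the same pattern, with a hypothetical status_t/do_work rather than the HIP API:

#include <cstdio>

enum class status_t { ok, fail };                          // hypothetical status type

[[nodiscard]] status_t do_work() { return status_t::ok; }  // hypothetical API

int main()
{
    do_work();                        // typically warns: return value discarded
    (void)do_work();                  // explicit, silent discard (the pattern used above)
    if (do_work() == status_t::fail)  // or check the status instead
        std::puts("work failed");
}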