Commit 3e0496fb authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Use multi_index for reductions (#400)

* Add functions to do multi-index for local strides as well

* Formatting

* Use same multi-index path for block_reduce

* Formatting

* Use multi-index calc in reduce

* Formatting

* Fix warning

* Fix compiler warning

* Disable some tidy checks
parent 78c83426
...@@ -72,6 +72,7 @@ rocm_enable_clang_tidy( ...@@ -72,6 +72,7 @@ rocm_enable_clang_tidy(
-google-runtime-references -google-runtime-references
-hicpp-braces-around-statements -hicpp-braces-around-statements
-hicpp-explicit-conversions -hicpp-explicit-conversions
-hicpp-member-init
-hicpp-no-array-decay -hicpp-no-array-decay
-hicpp-special-member-functions -hicpp-special-member-functions
-hicpp-uppercase-literal-suffix -hicpp-uppercase-literal-suffix
...@@ -80,9 +81,11 @@ rocm_enable_clang_tidy( ...@@ -80,9 +81,11 @@ rocm_enable_clang_tidy(
-llvm-header-guard -llvm-header-guard
-llvm-include-order -llvm-include-order
-misc-macro-parentheses -misc-macro-parentheses
-modernize-use-override -modernize-concat-nested-namespaces
-modernize-pass-by-value -modernize-pass-by-value
-modernize-use-default-member-init -modernize-use-default-member-init
-modernize-use-nodiscard
-modernize-use-override
-modernize-use-trailing-return-type -modernize-use-trailing-return-type
-modernize-use-transparent-functors -modernize-use-transparent-functors
-performance-type-promotion-in-math-fn -performance-type-promotion-in-math-fn
......
...@@ -207,6 +207,15 @@ auto always(T x) ...@@ -207,6 +207,15 @@ auto always(T x)
return always_f<T>{x}; return always_f<T>{x};
} }
// Identity function object: returns its argument unchanged, perfectly
// forwarding lvalues as references and rvalues by move. Also usable in
// unevaluated contexts (e.g. decltype) purely for type deduction.
struct id
{
    template <class U>
    constexpr U operator()(U&& value) const
    {
        // Equivalent to std::forward<U>(value) without requiring <utility>.
        return static_cast<U&&>(value);
    }
};
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -64,6 +64,7 @@ add_library(migraphx_device ...@@ -64,6 +64,7 @@ add_library(migraphx_device
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device) set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_set_soversion(migraphx_device ${PROJECT_VERSION}) rocm_set_soversion(migraphx_device ${PROJECT_VERSION})
rocm_clang_tidy_check(migraphx_device) rocm_clang_tidy_check(migraphx_device)
target_compile_options(migraphx_device PRIVATE -std=c++17)
target_link_libraries(migraphx_device migraphx hip::device -Wno-invalid-command-line-argument -amdgpu-target=gfx803 -amdgpu-target=gfx900 -amdgpu-target=gfx906) target_link_libraries(migraphx_device migraphx hip::device -Wno-invalid-command-line-argument -amdgpu-target=gfx803 -amdgpu-target=gfx900 -amdgpu-target=gfx906)
target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>) target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>) target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
......
...@@ -61,13 +61,13 @@ inline auto launch(hipStream_t stream, index_int global, index_int local) ...@@ -61,13 +61,13 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
} }
template <class F> template <class F>
__host__ __device__ auto gs_invoke(F&& f, index_int i, index idx) -> decltype(f(i, idx)) MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index idx) -> decltype(f(i, idx))
{ {
return f(i, idx); return f(i, idx);
} }
template <class F> template <class F>
__host__ __device__ auto gs_invoke(F&& f, index_int i, index) -> decltype(f(i)) MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype(f(i))
{ {
return f(i); return f(i);
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/gpu/device/launch.hpp> #include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/shape.hpp> #include <migraphx/gpu/device/shape.hpp>
#include <migraphx/functional.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -17,16 +18,26 @@ struct multi_index ...@@ -17,16 +18,26 @@ struct multi_index
hip_index id{}; hip_index id{};
hip_index stride{}; hip_index stride{};
template <class F> MIGRAPHX_DEVICE_CONSTEXPR auto for_stride(hip_index n) const
MIGRAPHX_DEVICE_CONSTEXPR void for_stride(hip_index n, F f) const
{ {
for(hip_index i = id; i < n; i = n.carry(i + stride)) // f should return void, but this helps with type deduction
{ return [=](auto f) -> decltype(f(hip_index{})) {
f(i); for(hip_index i = id; i < n; i = n.carry(i + stride))
} {
f(i);
}
};
} }
}; };
// Declared but intentionally never defined: for use only in unevaluated
// contexts (decltype) to name the result type of invoking a for_stride
// functor with the identity callback, i.e. the index type the functor
// hands to its callback. Calling it at runtime would be a link error.
template <class ForStride>
auto deduce_for_stride(ForStride fs) -> decltype(fs(id{}));
// Build a rank-1 multi_index for a flat (1-D) iteration space:
// id = {i} is this thread's starting element and stride = {n} is the
// step between successive elements it processes (the grid-stride).
MIGRAPHX_DEVICE_CONSTEXPR multi_index<1> make_multi_index(index_int i, index_int n)
{
    return {{i}, {n}};
}
template <index_int N> template <index_int N>
MIGRAPHX_DEVICE_CONSTEXPR multi_index<N> MIGRAPHX_DEVICE_CONSTEXPR multi_index<N>
make_multi_index(const hip_shape<N>& s, index_int i, index_int n) make_multi_index(const hip_shape<N>& s, index_int i, index_int n)
...@@ -42,30 +53,83 @@ make_multi_index(const hip_shape<N>& s, index_int i, const hip_array<index_int, ...@@ -42,30 +53,83 @@ make_multi_index(const hip_shape<N>& s, index_int i, const hip_array<index_int,
} }
template <index_int N> template <index_int N>
inline auto mi_launch(hipStream_t stream, const hip_shape<N>& s, index_int local = 1024) inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal)
{ {
assert(s.standard); assert(s.standard);
assert(s.elements() > 0); assert(s.elements() > 0);
index_int n = s.elements(); index_int n = s.elements();
index_int groups = (n + local - 1) / local; index_int groups = (n + nlocal - 1) / nlocal;
index_int nglobal = std::min<index_int>(128, groups) * local; index_int nglobal = std::min<index_int>(128, groups) * nlocal;
assert(groups > 0); assert(groups > 0);
assert(nglobal > 0); assert(nglobal > 0);
auto nglobal_multi = s.multi(nglobal); auto nglobal_multi = s.multi(nglobal);
// Skip checking this, since this will cause metadata to not be generated // Skip checking this, since this will cause metadata to not be generated
// for some unknown reason. // for some unknown reason.
// //
// assert(std::any_of(nglobal_multi.begin(), nglobal_multi.end(), [](auto x){return x>0;})); // assert(std::any_of(nglobal_multi.begin(), nglobal_multi.end(), [](auto x){return x>0;}));
return nglobal_multi;
}
// Compute the per-dimension multi-index strides for iterating shape s with
// `local` participating threads (the intra-workgroup counterpart of
// mi_nglobal). Requires a standard, non-empty shape.
template <index_int N>
inline auto mi_nlocal(const hip_shape<N>& s, index_int local)
{
    assert(s.standard);
    assert(s.elements() > 0);
    // Deliberately no assert over the components of the result: adding one
    // causes kernel metadata to not be generated, for an unknown reason.
    return s.multi(local);
}
template <index_int N>
inline auto mi_launch(hipStream_t stream, const hip_shape<N>& global, index_int nlocal = 1024)
{
auto nglobal_multi = mi_nglobal(global, nlocal);
auto nglobal = global.index(nglobal_multi);
return [=](auto f) { return [=](auto f) {
launch(stream, nglobal, local)([=](auto idx) { launch(stream, nglobal, nlocal)([=](auto idx) {
auto midx = make_multi_index(s, idx.global, nglobal_multi); auto midx = make_multi_index(global, idx.global, nglobal_multi);
midx.for_stride(s.lens, [&](auto i) { f(i); }); f(idx, midx.for_stride(global.lens));
}); });
}; };
} }
// Two-level multi-index launcher: `global` describes the per-workgroup
// iteration space and `local` the intra-workgroup space (e.g. the reduction
// slice), with `nlocal` threads per workgroup. Returns a launcher taking a
// callback f(idx, global_range, local_range), where the two ranges are
// for_stride functors (invoke them with a per-element callback).
// NOTE(review): each workgroup handles one global element at a time
// (mi_nglobal is called with nlocal of 1), while all nlocal threads
// cooperate over the local space — confirm against reduce_multi_impl usage.
template <index_int N>
inline auto mi_launch(hipStream_t stream,
                      const hip_shape<N>& global,
                      const hip_shape<N>& local,
                      index_int nlocal = 1024)
{
    // Precompute the stride multi-indices on the host; they are captured by
    // value into the device lambda below.
    auto nglobal_multi = mi_nglobal(global, 1);
    auto nglobal       = global.index(nglobal_multi);
    auto nlocal_multi  = mi_nlocal(local, nlocal);
    return [=](auto f) {
        launch(stream, nglobal * nlocal, nlocal)([=](auto idx) {
            // TODO: Use fast div for nlocal
            // idx.global / nlocal maps the flat global thread id to this
            // thread's workgroup, so all nlocal threads in a group share midx.
            auto midx = make_multi_index(global, idx.global / nlocal, nglobal_multi);
            auto lidx = make_multi_index(local, idx.local, nlocal_multi);
            f(idx, midx.for_stride(global.lens), lidx.for_stride(local.lens));
        });
    };
}
// Grid-stride style adapter over mi_launch: collapses the
// (idx, for_stride-range) callback interface down to a callback invoked
// once per multi-index element, preserving the old single-argument usage
// of the nary elementwise kernels.
template <index_int N>
inline auto mi_gs_launch(hipStream_t stream, const hip_shape<N>& global, index_int nlocal = 1024)
{
    return [=](auto f) {
        // Drop idx and drive the global range g directly with f.
        mi_launch(stream, global, nlocal)([=](auto, auto g) { g([&](auto i) { f(i); }); });
    };
}
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -25,7 +25,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY); ...@@ -25,7 +25,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY);
std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl; std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl;
template <class... Ts> template <class... Ts>
auto pack(Ts... xs) __device__ constexpr auto pack(Ts... xs)
{ {
return [=](auto f) { return f(xs...); }; return [=](auto f) { return f(xs...); };
} }
...@@ -36,7 +36,7 @@ auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, A ...@@ -36,7 +36,7 @@ auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, A
MIGRAPHX_TRACE_NARY_FUNCTION MIGRAPHX_TRACE_NARY_FUNCTION
shape s{result.get_shape().type(), result.get_shape().lens()}; shape s{result.get_shape().type(), result.get_shape().lens()};
hip_visit_all(s, result, args...)([&](auto standard_shape, auto output, auto... inputs) { hip_visit_all(s, result, args...)([&](auto standard_shape, auto output, auto... inputs) {
mi_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); }); mi_gs_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); });
}); });
} }
...@@ -61,10 +61,11 @@ auto nary_nonstandard_packed_impl(hipStream_t stream, ...@@ -61,10 +61,11 @@ auto nary_nonstandard_packed_impl(hipStream_t stream,
auto arg_shape = make_array(args...).front().get_shape(); auto arg_shape = make_array(args...).front().get_shape();
auto perm = find_permutation(arg_shape); auto perm = find_permutation(arg_shape);
auto s = reorder_shape(arg_shape, perm); auto s = reorder_shape(arg_shape, perm);
hip_visit_all(s, result.reshape(reorder_shape(result.get_shape(), perm)), args.reshape(s)...)( hip_visit_all(s,
[&](auto standard_shape, auto output, auto... inputs) { result.reshape(reorder_shape(result.get_shape(), perm)),
mi_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); }); args.reshape(s)...)([&](auto standard_shape, auto output, auto... inputs) {
}); mi_gs_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); });
});
} }
template <class F, class... Arguments> template <class F, class... Arguments>
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/gpu/device/launch.hpp> #include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/visit.hpp> #include <migraphx/gpu/device/visit.hpp>
#include <migraphx/gpu/device/multi_index.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -176,13 +177,18 @@ __device__ inline void dpp_reduce(float& x, sum) ...@@ -176,13 +177,18 @@ __device__ inline void dpp_reduce(float& x, sum)
#endif #endif
} }
template <index_int N, class Op, class T, class F> template <index_int N,
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) class Op,
class T,
class ForStride,
class F,
MIGRAPHX_REQUIRES(not std::is_integral<ForStride>{})>
__device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f)
{ {
using type = decltype(f(idx.local)); using type = decltype(f(deduce_for_stride(fs)));
MIGRAPHX_DEVICE_SHARED type buffer[N / 64]; MIGRAPHX_DEVICE_SHARED type buffer[N / 64];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); fs([&](auto i) { x = op(x, f(i)); });
dpp_reduce(x, op); dpp_reduce(x, op);
const auto ldsidx = idx.local / 64; const auto ldsidx = idx.local / 64;
...@@ -199,6 +205,18 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) ...@@ -199,6 +205,18 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
} }
return y; return y;
} }
// Backward-compatible overload: block-reduce over a flat range of n
// elements, f(i) supplying element i. Adapts the old integer-count
// interface onto the multi-index for_stride path above by building a
// rank-1 multi_index over this thread's local id/stride; the callback
// unwraps the rank-1 index with mi[0].
template <index_int N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
{
    auto midx = make_multi_index(idx.local, idx.nlocal());
    // Workaround hcc, create a local array: copy the id array and overwrite
    // element 0 with the bound n instead of constructing the index directly.
    auto fs = midx.id;
    fs[0] = n;
    return block_reduce<N>(
        idx, op, init, midx.for_stride(fs), [&](auto mi) __device__ { return f(mi[0]); });
}
#endif #endif
constexpr index_int compute_block_size(index_int n, index_int max_block_size) constexpr index_int compute_block_size(index_int n, index_int max_block_size)
{ {
...@@ -219,21 +237,21 @@ void reduce_multi_impl(hipStream_t stream, ...@@ -219,21 +237,21 @@ void reduce_multi_impl(hipStream_t stream,
const shape& reduce_slice) const shape& reduce_slice)
{ {
hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) { hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) {
auto nelements = result.get_shape().elements();
auto relements = reduce_slice.elements(); auto relements = reduce_slice.elements();
const index_int max_block_size = 256; const index_int max_block_size = 256;
const index_int block_size = compute_block_size(relements, max_block_size); const index_int block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ { mi_launch(stream, output.get_shape(), reduce_shape, block_size)(
const auto out_idx = i / block_size; [=](auto idx, auto global, auto local) __device__ {
auto base_idx = output.get_shape().multi(out_idx); global([&](auto i) __device__ {
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ { auto r =
auto reduce_idx = reduce_shape.multi(j); block_reduce<max_block_size>(idx, op, init, local, [&](auto j) __device__ {
return read_input(input[reduce_idx + base_idx]); return read_input(input[i + j]);
});
if(idx.local == 0)
output[i] = read_output(r);
});
}); });
if(idx.local == 0)
output.data()[out_idx] = read_output(r);
});
}); });
} }
......
...@@ -10,7 +10,7 @@ namespace gpu { ...@@ -10,7 +10,7 @@ namespace gpu {
namespace device { namespace device {
template <class F> template <class F>
void visit_tensor_size(index_int n, F f) constexpr void visit_tensor_size(index_int n, F f)
{ {
switch(n) switch(n)
{ {
......
...@@ -69,7 +69,7 @@ void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument ...@@ -69,7 +69,7 @@ void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument
}); });
} }
void sync_stream(hipStream_t stream) { hipStreamSynchronize(stream); } void sync_stream(hipStream_t stream) { (void)hipStreamSynchronize(stream); }
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment