Unverified commit f9a5b81e authored by Paul Fultz II, committed by GitHub

Reduce with runtime compilation (#1150)

There is a significant improvement on larger tensors; with half precision the new kernel is almost 50% faster (1.73 ms vs. 1.17 ms, roughly a 1.48x speedup):

lens: [1024, 384, 768]
gpu::code_object[code_object=13832,symbol_name=kernel,global=39321600,local=256,]: 1.16685ms
gpu::reduce_sum[axes={2}]: 1.73126ms
Also, for non-trivial layouts this can sometimes be more than 2x faster:

lens: [64, 1024, 768, 4]
gpu::code_object[code_object=13832,symbol_name=kernel,global=39321600,local=256,]: 1.1706ms
gpu::reduce_sum[axes={1}]: 2.63375ms
Of course, if the stride becomes larger, this speed improvement diminishes due to poor memory access patterns. A lane_reduce instead of a block_reduce is needed for such kernels; I plan to address that in a future PR.
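To illustrate the distinction, here is a minimal, standalone HIP sketch that is not taken from this PR: the kernel names, the 256-thread block size, and the layout assumptions are all hypothetical. A block-level reduce dedicates a whole workgroup to one output element, which works well when the reduced elements are contiguous; a lane-level reduce gives each thread its own output element, so even when the reduced elements are far apart, neighbouring lanes still touch neighbouring addresses on every step.

```cpp
// Minimal sketch only; not the MIGraphX implementation.
#include <hip/hip_runtime.h>

// Block-level reduce: one workgroup per output element. Launch with
// gridDim.x == number of outputs and blockDim.x == 256 (power of two).
// Consecutive lanes read consecutive reduce elements, which is coalesced
// when the reduced axis is the fastest-varying one.
__global__ void block_reduce_sum(const float* in, float* out, int relements)
{
    __shared__ float buffer[256];
    float x = 0;
    for(int j = threadIdx.x; j < relements; j += blockDim.x)
        x += in[blockIdx.x * relements + j];
    buffer[threadIdx.x] = x;
    __syncthreads();
    for(int s = blockDim.x / 2; s > 0; s /= 2)
    {
        if(threadIdx.x < s)
            buffer[threadIdx.x] += buffer[threadIdx.x + s];
        __syncthreads();
    }
    if(threadIdx.x == 0)
        out[blockIdx.x] = buffer[0];
}

// Lane-level reduce: one thread per output element, walking the reduced
// axis with a fixed stride. When the outputs are contiguous (the stride
// between reduce elements equals n), neighbouring threads read
// neighbouring addresses at every step, so a large stride no longer
// hurts coalescing.
__global__ void lane_reduce_sum(
    const float* in, float* out, int n, int relements, int stride)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i >= n)
        return;
    float x = 0;
    for(int j = 0; j < relements; j++)
        x += in[i + j * stride];
    out[i] = x;
}
```

Which of the two wins depends on where the reduced axis sits in the layout; picking between them is what the future PR mentioned above would address.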

Finally, this also adds a MIGRAPHX_GPU_DUMP_ASM environment variable which, when enabled, prints the generated assembly when the kernel is compiled.
parent 12007dba
......@@ -4,7 +4,7 @@ CheckOptions:
- key: bugprone-unused-return-value.CheckedFunctions
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
- key: cppcoreguidelines-macro-usage.AllowedRegexp
value: 'DEBUG|FALLTHROUGH|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_'
value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_'
- key: modernize-loop-convert.MinConfidence
value: risky
- key: modernize-loop-convert.NamingStyle
......
......@@ -28,7 +28,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
{
params += " " + src.path.filename().string();
if(out.empty())
out = src.path.stem().string() + ".o";
out = src.path.stem().string() + out_ext;
}
}
......
......@@ -24,6 +24,7 @@ struct src_compiler
std::string flags = "";
std::string output = "";
std::string launcher = "";
std::string out_ext = ".o";
std::function<fs::path(fs::path)> process = nullptr;
std::vector<char> compile(const std::vector<src_file>& srcs) const;
};
......
......@@ -114,6 +114,7 @@ foreach(KERNEL_FILE ${KERNEL_FILES})
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp "#include <migraphx/kernels/${KERNEL_BASE_FILE}.hpp>\n")
target_sources(kernel_file_check PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp)
endforeach()
target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256)
target_include_directories(kernel_file_check PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/>)
target_link_libraries(kernel_file_check compile_for_gpu)
......
......@@ -21,6 +21,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
#if MIGRAPHX_USE_HIPRTC
......@@ -184,6 +185,13 @@ bool has_compiler_launcher()
return result;
}
src_compiler assemble(src_compiler compiler)
{
compiler.out_ext = ".S";
compiler.flags = replace_string(compiler.flags, " -c", " -S");
return compiler;
}
std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{
......@@ -238,6 +246,12 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
MIGRAPHX_THROW("Missing hsaco");
};
if(enabled(MIGRAPHX_GPU_DUMP_ASM{}))
{
std::cout << assemble(compiler).compile(srcs).data() << std::endl;
}
return {compiler.compile(srcs)};
}
......
......@@ -17,7 +17,9 @@ struct run_op : action<run_op>
auto name = v.at("name").to<std::string>();
if(not contains(name, "::"))
name = "gpu::" + name;
auto op = make_op(name);
auto op = make_op(name);
if(v.contains("fields"))
op.from_value(v.at("fields"));
double t = time_op(ctx, op, inputs);
std::cout << op << ": " << t << "ms" << std::endl;
}
......
......@@ -31,6 +31,13 @@ struct hip_compile_options
void set_launch_params(const value& v,
const std::function<std::size_t(std::size_t local)>& compute_global,
std::size_t default_local = 1024);
void
set_launch_params(const value& v, std::size_t default_global, std::size_t default_local = 1024)
{
set_launch_params(
v, [=](auto) { return default_global; }, default_local);
}
};
/// Compute global for n elements, but max out on target-specific upper limit
......
......@@ -46,7 +46,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); }))
return 1;
else
return 4;
return 256;
}
static std::size_t vectorize_elements(const std::vector<shape>& inputs)
{
......
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
static const char* const simple_reduce_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void kernel(void* input_p, void* output_p)
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
simple_reduce(${reduction}, ${init}, input, output, ${read}, ${write});
});
}
}
} // namespace migraphx
)__migraphx__";
constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size = 1024)
{
size_t block_size = 128;
while(block_size <= max_block_size and block_size <= n)
block_size *= 2;
return block_size / 2;
}
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
{
return get_reduce_elements(to_shapes(inputs));
}
struct reduce_compiler : compiler<reduce_compiler>
{
std::vector<std::string> names() const
{
return {"reduce", "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_prod"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
auto reduce_elements = get_reduce_elements(inputs);
auto block_size = compute_block_size(reduce_elements, 256);
options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() * block_size, 256), block_size);
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()},
{"init", v.get("init", std::string{"0"})},
{"read", v.get("read", identity)},
{"write", v.get("write", identity)},
{"preamble", v.get("preamble", std::string{})}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
value v = value::object{};
auto reduce_elements = get_reduce_elements(ins->inputs());
if(op.name() == "reduce_sum")
{
v["reduction"] = "op::sum{}";
}
else if(op.name() == "reduce_mean")
{
v["reduction"] = "op::sum{}";
v["write"] = "op::mean{" + std::to_string(reduce_elements) + "}";
}
else if(op.name() == "reduce_max")
{
v["reduction"] = "op::max{}";
v["init"] = "lowest{}";
}
else if(op.name() == "reduce_min")
{
v["reduction"] = "op::min{}";
v["init"] = "highest{}";
}
else if(op.name() == "reduce_prod")
{
v["reduction"] = "op::product{}";
v["init"] = "1";
}
else
{
MIGRAPHX_THROW("Unsupported reduce");
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -106,6 +106,35 @@ constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, I
}
}
template <class InputIt1, class InputIt2, class T, class BinaryOperation1, class BinaryOperation2>
constexpr T inner_product(InputIt1 first1,
InputIt1 last1,
InputIt2 first2,
T init,
BinaryOperation1 op1,
BinaryOperation2 op2)
{
while(first1 != last1)
{
init = op1(init, op2(*first1, *first2));
++first1;
++first2;
}
return init;
}
template <class InputIt1, class InputIt2, class T>
constexpr T inner_product(InputIt1 first1, InputIt1 last1, InputIt2 first2, T init)
{
return inner_product(
first1,
last1,
first2,
init,
[](auto x, auto y) { return x + y; },
[](auto x, auto y) { return x * y; });
}
} // namespace migraphx
#endif
......@@ -74,6 +74,7 @@ struct array
constexpr const T* data() const { return d; }
constexpr index_constant<N> size() const { return {}; }
constexpr auto empty() const { return size() == _c<0>; }
constexpr T* begin() { return d; }
constexpr const T* begin() const { return d; }
......
......@@ -73,12 +73,19 @@ assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& funct
abort();
}
#ifdef MIGRAPHX_DEBUG
#define MIGRAPHX_ASSERT(cond) \
// NOLINTNEXTLINE
#define MIGRAPHX_CHECK(cond) \
((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
assert_fail(private_migraphx_xs...); \
}(#cond, __FILE__, MIGRAPHX_STRINGIZE(__LINE__), __PRETTY_FUNCTION__))
#ifdef MIGRAPHX_DEBUG
#define MIGRAPHX_ASSERT MIGRAPHX_CHECK
#define MIGRAPHX_ASSUME MIGRAPHX_CHECK
#define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false)
#else
#define MIGRAPHX_ASSUME __builtin_assume
#define MIGRAPHX_UNREACHABLE __builtin_unreachable
#define MIGRAPHX_ASSERT(cond)
#endif
......
#ifndef MIGRAPHX_GUARD_KERNELS_DPP_HPP
#define MIGRAPHX_GUARD_KERNELS_DPP_HPP
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx {
#ifndef MIGRAPHX_HAS_DPP
#define MIGRAPHX_HAS_DPP 1
#endif
#if MIGRAPHX_HAS_DPP
constexpr unsigned int dpp_row_shr(unsigned int x) { return 0x110u | x; }
constexpr unsigned int dpp_row_bcast(unsigned int x)
{
unsigned int y = 0;
switch(x)
{
case 15: y = 0x142; break;
case 31: y = 0x143; break;
default: MIGRAPHX_UNREACHABLE();
}
return y;
}
template <unsigned int DppCtrl,
unsigned int RowMask = 0xf,
unsigned int BankMask = 0xf,
bool BoundCtrl = false,
class T>
__device__ T dpp_mov(T& x)
{
static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
union type
{
uint32_t reg[n];
T data;
};
type output{};
type input{};
// cppcheck-suppress unreadVariable
input.data = x;
for(index_int i = 0; i < n; i++)
{
output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
}
return output.data;
}
#endif
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DPP_HPP
......@@ -3,6 +3,7 @@
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
namespace migraphx {
......@@ -12,23 +13,23 @@ struct index
index_int local = 0;
index_int group = 0;
__device__ index_int nglobal() const
{
#ifdef MIGRAPHX_NGLOBAL
return MIGRAPHX_NGLOBAL;
constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const { return {}; }
#else
__device__ index_int nglobal() const
{
return blockDim.x * gridDim.x; // NOLINT
#endif
}
#endif
__device__ index_int nlocal() const
{
#ifdef MIGRAPHX_NLOCAL
return MIGRAPHX_NLOCAL;
constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const { return {}; }
#else
return blockDim.x; // NOLINT
#endif
__device__ index_int nlocal() const
{
return blockDim.x; // NOLINT
}
#endif
template <class F>
__device__ void global_stride(index_int n, F f) const
......
......@@ -40,6 +40,10 @@ constexpr T as_float(T x)
return fname(x, xs...); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \
inline auto __device__ name(type x, type y)->type { return fname(x, y); }
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
......@@ -154,6 +158,38 @@ constexpr auto where(bool cond, const T& a, const U& b)
return cond ? a : b;
}
MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
// Add overloads for half that calls the float version
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf)
template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr auto max(const T& a, const T& b)
{
return where(a < b, b, a);
}
template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr auto min(const T& a, const T& b)
{
return where(a < b, a, b);
}
template <class T, class U, MIGRAPHX_REQUIRES(not is_same<T, U>{} and not is_any_vec<T, U>())>
constexpr auto max(const T& a, const U& b)
{
return max<common_type_t<T, U>>(a, b);
}
template <class T, class U, MIGRAPHX_REQUIRES(not is_same<T, U>{} and not is_any_vec<T, U>())>
constexpr auto min(const T& a, const U& b)
{
return min<common_type_t<T, U>>(a, b);
}
MIGRAPHX_DEVICE_MATH_VEC(abs)
MIGRAPHX_DEVICE_MATH_VEC(acos)
MIGRAPHX_DEVICE_MATH_VEC(acosh)
......@@ -169,6 +205,8 @@ MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(max)
MIGRAPHX_DEVICE_MATH_VEC(min)
MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(round)
MIGRAPHX_DEVICE_MATH_VEC(rsqrt)
......@@ -179,18 +217,6 @@ MIGRAPHX_DEVICE_MATH_VEC(tan)
MIGRAPHX_DEVICE_MATH_VEC(tanh)
MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U>
constexpr auto max(const T& a, const U& b)
{
return where(a < b, b, a);
}
template <class T, class U>
constexpr auto min(const T& a, const U& b)
{
return where(a > b, b, a);
}
template <class T, class U>
constexpr auto convert(U v)
{
......
#ifndef MIGRAPHX_GUARD_KERNELS_OPS_HPP
#define MIGRAPHX_GUARD_KERNELS_OPS_HPP
#include <migraphx/kernels/math.hpp>
namespace migraphx {
namespace op {
struct sum
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x + y;
}
};
struct product
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x * y;
}
};
struct id
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return x;
}
};
struct mean
{
index_int item_num = 1;
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return x / static_cast<T>(item_num);
}
};
struct max
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return migraphx::max(x, y);
}
};
struct min
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return migraphx::min(x, y);
}
};
} // namespace op
struct lowest
{
template <class T>
constexpr operator T() const
{
return numeric_lowest<T>();
}
};
struct highest
{
template <class T>
constexpr operator T() const
{
return numeric_max<T>();
}
};
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_OPS_HPP
......@@ -140,6 +140,10 @@ struct basic_printer
{
return print_ulong(value);
}
__host__ __device__ const basic_printer& operator<<(migraphx::half value) const
{
return print_double(value);
}
__host__ __device__ const basic_printer& operator<<(float value) const
{
return print_double(value);
......
#ifndef MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
#define MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
#include <migraphx/kernels/dpp.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ops.hpp>
namespace migraphx {
#if MIGRAPHX_HAS_DPP
template <class T, class Op>
__device__ void dpp_reduce(T& in, Op op)
{
T out{};
out = dpp_mov<dpp_row_shr(1)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(2)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out);
#if __AMDGCN_WAVEFRONT_SIZE == 64
out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out);
out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
in = op(in, out);
#endif
}
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
#elif __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:15 row_mask:0xa\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" \
: "=v"(x) \
: "0"(x))
#else
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" \
"s_nop 1\n" \
: "=v"(x) \
: "0"(x))
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix) \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(int32_t& x, op) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); \
} \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
MIGRAPHX_DPP_REDUCE(op::sum, v_add)
MIGRAPHX_DPP_REDUCE(op::max, v_max)
MIGRAPHX_DPP_REDUCE(op::min, v_min)
MIGRAPHX_DPP_REDUCE(op::product, v_mul)
template <class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
{
#if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr index_int lanes_per_thread = 16;
#else
constexpr index_int lanes_per_thread = 64;
#endif
using type = decltype(f(0));
__shared__ type buffer[idx.nlocal() / lanes_per_thread];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
dpp_reduce(x, op);
const auto ldsidx = idx.local / lanes_per_thread;
if((idx.local % lanes_per_thread) == lanes_per_thread - 1)
{
buffer[ldsidx] = x;
}
__syncthreads();
type y = init;
for(index_int i = 0; i < idx.nlocal() / lanes_per_thread; i++)
{
y = op(y, buffer[i]);
}
return y;
}
#else
template <class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
{
using type = decltype(f(0));
__shared__ type buffer[idx.nlocal()];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
buffer[idx.local] = x;
__syncthreads();
for(index_int s = 1; s < idx.nlocal(); s *= 2)
{
const index_int index = 2 * s * idx.local;
if(index + s < idx.nlocal())
{
buffer[index] = op(buffer[index], buffer[index + s]);
}
__syncthreads();
}
return buffer[0];
}
#endif
template <class Input, class T, class Output>
constexpr auto reduce_slice(Input input, T i, Output)
{
constexpr auto lens = transform(get_shape_c<Input>{}.lens,
get_shape_c<Output>{}.lens,
[](index_int x, index_int y) -> index_int {
if(x == y)
return 1;
return x;
});
;
constexpr auto s = make_shape(lens, get_shape_c<Input>{}.strides);
return make_tensor_view(&input[i], s);
}
template <class Op, class T, class Input, class Output, class ReadInput, class WriteOuput>
__device__ void
simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOuput write)
{
auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements();
constexpr auto relements = get_shape_c<Input>{}.elements() / get_shape_c<Output>{}.elements();
idx.global_stride(nelements * idx.nlocal(), [&](auto i) {
const auto out_idx = output.get_shape().multi(i / idx.nlocal());
auto rs = reduce_slice(input, out_idx, output);
MIGRAPHX_ASSERT(relements == rs.get_shape().elements());
auto r = block_reduce(idx, op, init, relements, [&](auto j) { return read(rs[j]); });
if(idx.local == 0)
output[out_idx] = write(r);
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
......@@ -83,11 +83,12 @@ struct shape
}
}
/// Convert single index into a multi-index
constexpr index_array multi(index_int idx) const
{
index_array result;
index_int tidx = idx;
for(std::ptrdiff_t is = result.size() - 1; is > 0; is--)
for(diff_int is = result.size() - 1; is > 0; is--)
{
result[is] = tidx % lens[is];
tidx = tidx / lens[is];
......@@ -95,6 +96,13 @@ struct shape
result[0] = tidx;
return result;
}
/// Convert multi-index into a single index
constexpr index_int single(index_array idx) const
{
if(idx.empty())
return 0;
return inner_product(lens.begin() + 1, lens.end(), idx.begin(), idx.back());
}
constexpr shape get_shape() const { return *this; }
......
......@@ -21,9 +21,10 @@ struct tensor_view_iterator_read
template <class T, class Shape>
struct tensor_view
{
using type = T;
using shape_type = Shape;
using iterator = basic_iota_iterator<tensor_view_iterator_read<const tensor_view>, index_int>;
using type = T;
using shape_type = Shape;
using index_array = typename Shape::index_array;
using iterator = basic_iota_iterator<tensor_view_iterator_read<const tensor_view>, index_int>;
constexpr Shape get_shape() const { return Shape{}; }
constexpr auto size() const { return get_shape().elements(); }
......@@ -40,6 +41,13 @@ struct tensor_view
constexpr auto begin() const { return iterator{0, {this}}; }
constexpr auto end() const { return iterator{this->size(), {this}}; }
constexpr auto begin_at(index_array i) const
{
MIGRAPHX_ASSERT(get_shape().single(i) < get_shape().elements());
MIGRAPHX_ASSERT(get_shape().index(i) < get_shape().element_space());
return iterator{get_shape().single(i), {this}};
}
template <class U>
constexpr tensor_view<U, Shape> with(U* y) const
{
......@@ -50,6 +58,9 @@ struct tensor_view
T* x;
};
template <class T>
using get_shape_c = typename T::shape_type;
template <class T, class Shape>
constexpr tensor_view<T, Shape> make_tensor_view(T* x, Shape)
{
......