Unverified Commit 4c72cc95 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add lane reduction (#1180)

With reductions such as {2048, 2, 1456} on axis 1, this is 23x faster than using our new block_reduce, and it's even over 100x faster than our original reduce_sum:

# lane
gpu::code_object[code_object=13736,symbol_name=kernel,global=2981888,local=1024,]: 0.0672928ms
# block
gpu::code_object[code_object=13800,symbol_name=kernel,global=39321600,local=64,]: 1.46072ms
# original
gpu::reduce_sum[axes={1}]: 6.73456ms
There is some basic logic to pick between lane and block reduce automatically.
parent 36656030
...@@ -20,7 +20,7 @@ struct run_op : action<run_op> ...@@ -20,7 +20,7 @@ struct run_op : action<run_op>
auto op = make_op(name); auto op = make_op(name);
if(v.contains("fields")) if(v.contains("fields"))
op.from_value(v.at("fields")); op.from_value(v.at("fields"));
double t = time_op(ctx, op, inputs); double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl; std::cout << op << ": " << t << "ms" << std::endl;
} }
}; };
......
...@@ -30,7 +30,7 @@ __global__ void kernel(void* input_p, void* output_p) ...@@ -30,7 +30,7 @@ __global__ void kernel(void* input_p, void* output_p)
{ {
make_tensors()(input_p, output_p)([](auto input, auto output) { make_tensors()(input_p, output_p)([](auto input, auto output) {
simple_reduce(${reduction}, ${init}, input, output, ${read}, ${write}); simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
}); });
} }
...@@ -57,6 +57,40 @@ static std::size_t get_reduce_elements(const std::vector<instruction_ref>& input ...@@ -57,6 +57,40 @@ static std::size_t get_reduce_elements(const std::vector<instruction_ref>& input
return get_reduce_elements(to_shapes(inputs)); return get_reduce_elements(to_shapes(inputs));
} }
// Compute the per-axis reduction lengths: for each axis, the length is 1
// when input and output agree (axis is not reduced) and the input length
// when they differ (axis is reduced).
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
                                                const std::vector<std::size_t>& output_lens)
{
    std::vector<std::size_t> reduce_lens;
    reduce_lens.reserve(output_lens.size());
    for(std::size_t axis = 0; axis < output_lens.size(); axis++)
    {
        // Unreduced axes contribute a length of 1 so only reduced axes count.
        if(output_lens[axis] == input_lens[axis])
            reduce_lens.push_back(1);
        else
            reduce_lens.push_back(input_lens[axis]);
    }
    return reduce_lens;
}
// Heuristic choice between the "lane" and "block" reduction kernels based
// on the smallest stride among the reduced axes: strided (non-contiguous)
// reductions favor lane reduce, contiguous ones favor block reduce.
static std::string get_reduce_algo(const std::vector<shape>& inputs)
{
    const auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
    const auto init  = std::numeric_limits<std::size_t>::max();
    // Minimum stride over reduced axes; axes with length 1 are ignored.
    auto min_stride = init;
    auto stride_it  = inputs.front().strides().begin();
    for(auto len : rlens)
    {
        const auto stride = *stride_it++;
        if(len != 1)
            min_stride = std::min(min_stride, stride);
    }
    return min_stride > 2 ? "lane" : "block";
}
struct reduce_compiler : compiler<reduce_compiler> struct reduce_compiler : compiler<reduce_compiler>
{ {
std::vector<std::string> names() const std::vector<std::string> names() const
...@@ -68,20 +102,33 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -68,20 +102,33 @@ struct reduce_compiler : compiler<reduce_compiler>
{ {
hip_compile_options options; hip_compile_options options;
auto reduce_elements = get_reduce_elements(inputs); auto reduce_elements = get_reduce_elements(inputs);
auto algo = v.get("algo", get_reduce_algo(inputs));
if(algo == "block")
{
auto block_size = compute_block_size(reduce_elements, 256); auto block_size = compute_block_size(reduce_elements, 256);
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() * block_size, 256), block_size); v, compute_global_for(ctx, inputs.back().elements() * block_size, 256), block_size);
}
else if(algo == "lane")
{
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements(), 256));
}
else
{
MIGRAPHX_THROW("Unknown reduce algo: " + algo);
}
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
std::string identity = "[](auto x) { return x; }"; std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel, auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()}, {{"reduction", v.at("reduction").to<std::string>()},
{"init", v.get("init", std::string{"0"})}, {"init", v.get("init", std::string{"0"})},
{"read", v.get("read", identity)}, {"read", v.get("read", identity)},
{"write", v.get("write", identity)}, {"write", v.get("write", identity)},
{"algo", algo},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -42,6 +42,32 @@ struct print_buffer ...@@ -42,6 +42,32 @@ struct print_buffer
pos++; pos++;
} }
} }
// Append an integer value in decimal. The SFINAE constraint enables this
// overload only for types supporting % and unary minus.
template <class T, class = decltype(T{} % 10, -T{})>
constexpr void append(T i)
{
    if(i < 0)
    {
        append('-');
        // NOTE(review): negating the minimum value of a signed type overflows.
        i = -i;
    }
    // Recurse to emit the higher-order digits first, then the last digit.
    if(i > 9)
        append(i / 10);
    append(static_cast<char>((i % 10) + '0'));
}
// Append a NUL-terminated string character by character, copying at most
// 512 characters as a guard against unterminated input.
constexpr void append(const char* str)
{
    if(str == nullptr)
        return;
    for(int remaining = 512; remaining > 0 and *str != 0; remaining--)
    {
        append(*str);
        str++;
    }
}
template <size_t M> template <size_t M>
constexpr void append(const char (&array)[M]) constexpr void append(const char (&array)[M])
...@@ -54,14 +80,36 @@ struct print_buffer ...@@ -54,14 +80,36 @@ struct print_buffer
// Format all arguments into one fixed-size buffer and emit it with a single
// printf call. Capacity is a fixed 1024 bytes; append is presumed to guard
// against overflow internally -- TODO confirm in print_buffer.
// (Reconstructed: the diff render fused the removed and added lines here.)
template <class... Ts>
__host__ __device__ void print(const Ts&... xs)
{
    print_buffer<1024> buffer;
    swallow{(buffer.append(xs), 0)...};
    printf("%s", buffer.buffer);
}
} // namespace debug } // namespace debug
// Minimal pre-C++20 source_location built on compiler intrinsics. When used
// as a defaulted function argument, the intrinsics are evaluated at the call
// site, so the caller's line/file/function are captured.
struct source_location
{
    int line = __builtin_LINE();
    const char* file = __builtin_FILE();
    const char* function = __builtin_FUNCTION();
};
// Pairs a value of type T with the source_location of the call site that
// constructed it. The converting constructor defaults its location
// parameter, so passing a plain value records the caller's location
// implicitly; the conversion operators let the wrapper stand in wherever a
// T or a source_location is expected.
template <class T>
struct source_location_capture
{
    T x;                 // wrapped value
    source_location loc; // location of the constructing call site

    // Intentionally implicit: the whole point is transparent capture.
    // SFINAE requires that T(U{}) is well-formed.
    template <class U, class = decltype(T(U{}))>
    constexpr source_location_capture(U px, source_location ploc = source_location{})
        : x(px), loc(ploc)
    {
    }
    constexpr operator source_location() const { return loc; }
    constexpr operator T() const { return x; }
};
// noreturn cannot be used on this function because abort in hip is broken // noreturn cannot be used on this function because abort in hip is broken
template <class T1, class T2, class T3, class T4> template <class T1, class T2, class T3, class T4>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
...@@ -73,20 +121,38 @@ assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& funct ...@@ -73,20 +121,38 @@ assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& funct
abort(); abort();
} }
// Overload of assert_fail taking a captured source_location: prints
// "<file>:<line>: <function>: error: <message parts>" and aborts.
// MIGRAPHX_HIP_NORETURN is used instead of [[noreturn]] because abort in
// hip is broken (see the note on the other overload above).
template <class... Ts>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_location& loc,
                                                                  Ts... xs)
{
    debug::print(loc.file, ":", loc.line, ": ", loc.function, ": error: ", xs..., "\n");
    abort();
}
// Evaluate cond; on failure, forward the remaining arguments to assert_fail.
// The lambda ensures the failure arguments are only evaluated on the false
// branch. (Reconstructed: the diff render fused removed and added lines.)
// NOLINTNEXTLINE
#define MIGRAPHX_ASSERT_FAIL(cond, ...)                     \
    ((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
        assert_fail(private_migraphx_xs...);                \
    }(__VA_ARGS__))

// Classic assert-style check reporting the stringized condition and the
// textual file/line/function of the failure site.
// NOLINTNEXTLINE
#define MIGRAPHX_CHECK(cond) \
    MIGRAPHX_ASSERT_FAIL(cond, #cond, __FILE__, __LINE__, __PRETTY_FUNCTION__)

#ifdef MIGRAPHX_DEBUG
// In debug builds, wrap T so the caller's source location is captured.
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) source_location_capture<T>
#define MIGRAPHX_WARN(cond, loc, ...) MIGRAPHX_ASSERT_FAIL(cond, loc, __VA_ARGS__)
#define MIGRAPHX_ASSERT MIGRAPHX_CHECK
#define MIGRAPHX_ASSUME MIGRAPHX_CHECK
#define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false)
#else
// In release builds the capture wrapper is the identity and checks compile away.
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) T
#define MIGRAPHX_ASSUME __builtin_assume
#define MIGRAPHX_UNREACHABLE __builtin_unreachable
#define MIGRAPHX_ASSERT(cond)
#define MIGRAPHX_WARN(...)
#endif
} // namespace migraphx } // namespace migraphx
......
...@@ -124,8 +124,8 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) ...@@ -124,8 +124,8 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
} }
#endif #endif
template <class Input, class T, class Output> template <class Output, class Input, class T>
constexpr auto reduce_slice(Input input, T i, Output) constexpr auto reduce_slice(Input input, T i)
{ {
constexpr auto lens = transform(get_shape_c<Input>{}.lens, constexpr auto lens = transform(get_shape_c<Input>{}.lens,
get_shape_c<Output>{}.lens, get_shape_c<Output>{}.lens,
...@@ -136,23 +136,126 @@ constexpr auto reduce_slice(Input input, T i, Output) ...@@ -136,23 +136,126 @@ constexpr auto reduce_slice(Input input, T i, Output)
}); });
; ;
constexpr auto s = make_shape(lens, get_shape_c<Input>{}.strides); constexpr auto s = make_shape(lens, get_shape_c<Input>{}.strides);
MIGRAPHX_ASSERT((input.get_shape().index(i) + s.element_space()) <=
input.get_shape().element_space());
return make_tensor_view(&input[i], s); return make_tensor_view(&input[i], s);
} }
template <class Op, class T, class Input, class Output, class ReadInput, class WriteOuput> namespace reduce {
__device__ void
// Adapt f so that every argument is first passed through slicer.
// (Reconstructed: the diff render fused the removed simple_reduce signature
// with this function's template header on one line.)
template <class Slicer, class F>
constexpr auto sliced(Slicer slicer, F f)
{
    return [=](auto x, auto... xs) {
        // TODO: assert all elements are the same
        return f(slicer(x), slicer(xs)...);
    };
}
struct block
{ {
template <class Slicer>
struct reducer
{
index idx;
Slicer slicer;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
return read(x[j], xs[j]...);
});
});
}
template <class F>
__device__ void outer(F f) const
{
if(idx.local == 0)
f();
}
};
template <class Slicer>
static __device__ auto make(index idx, Slicer slicer)
{
return reducer<Slicer>{idx, slicer};
}
template <class Output, class F>
static __device__ void run(F f)
{
auto idx = make_index(); auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements(); constexpr auto nelements = get_shape_c<Output>{}.elements();
constexpr auto relements = get_shape_c<Input>{}.elements() / get_shape_c<Output>{}.elements();
idx.global_stride(nelements * idx.nlocal(), [&](auto i) { idx.global_stride(nelements * idx.nlocal(), [&](auto i) {
const auto out_idx = output.get_shape().multi(i / idx.nlocal()); const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal());
auto rs = reduce_slice(input, out_idx, output); f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
MIGRAPHX_ASSERT(relements == rs.get_shape().elements()); });
auto r = block_reduce(idx, op, init, relements, [&](auto j) { return read(rs[j]); }); }
if(idx.local == 0) };
output[out_idx] = write(r);
// Lane reduce strategy: each thread computes the full reduction for its own
// output element with a sequential loop; no cross-thread cooperation needed.
struct lane
{
    template <class Slicer>
    struct reducer
    {
        index idx;
        Slicer slicer;

        // Returns a callable that, given the input tensor(s), folds all
        // sliced elements sequentially with op, starting from init.
        template <class Op, class T, class Read>
        __device__ auto reduce(Op op, T init, Read read) const
        {
            return sliced(slicer, [=](auto x, auto... xs) {
                using type           = typename decltype(x)::type;
                const auto relements = x.get_shape().elements();
                type acc             = init;
                for(index_int k = 0; k < relements; k++)
                    acc = op(acc, read(x[k], xs[k]...));
                return acc;
            });
        }

        // Every thread owns its output element, so f always runs.
        template <class F>
        __device__ void outer(F f) const
        {
            f();
        }
    };

    template <class Slicer>
    static __device__ auto make(index idx, Slicer slicer)
    {
        return reducer<Slicer>{idx, slicer};
    }

    // One thread per output element.
    template <class Output, class F>
    static __device__ void run(F f)
    {
        auto idx                 = make_index();
        constexpr auto nelements = get_shape_c<Output>{}.elements();
        idx.global_stride(nelements, [&](auto i) {
            const auto out_idx = get_shape_c<Output>{}.multi(i);
            f(out_idx,
              make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
        });
    }
};
} // namespace reduce
// Generic reduction entry point. Algo (reduce::block or reduce::lane) picks
// the parallelization strategy; op/init define the reduction, read maps each
// input element before reducing, and write maps the reduced value before it
// is stored. (Closing braces reconstructed: the diff render duplicated them.)
template <class Algo,
          class Op,
          class T,
          class Input,
          class Output,
          class ReadInput,
          class WriteOuput>
__device__ void
simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOuput write)
{
    Algo::template run<Output>([&](auto out_idx, auto reducer) {
        auto result = reducer.reduce(op, init, read)(input);
        // outer() ensures only the thread(s) designated by Algo do the store.
        reducer.outer([&] { output[out_idx] = write(result); });
    });
}
......
...@@ -29,11 +29,23 @@ struct tensor_view ...@@ -29,11 +29,23 @@ struct tensor_view
constexpr Shape get_shape() const { return Shape{}; } constexpr Shape get_shape() const { return Shape{}; }
constexpr auto size() const { return get_shape().elements(); } constexpr auto size() const { return get_shape().elements(); }
struct index_to_offset
{
index_int offset;
template <class U> template <class U>
constexpr T& operator[](U i) const constexpr index_to_offset(U i) : offset(Shape{}.index(i))
{ {
MIGRAPHX_ASSERT(get_shape().index(i) < get_shape().element_space()); }
return x[get_shape().index(i)]; };
// Bounds-checked element access. In debug builds the argument is wrapped in
// source_location_capture, so MIGRAPHX_WARN can report the caller's source
// location on an out-of-bounds offset; in release builds the check compiles
// away. (Trailing brace reconstructed: the diff render duplicated it.)
constexpr T& operator[](MIGRAPHX_CAPTURE_SOURCE_LOCATION(index_to_offset) i) const
{
    index_to_offset ito = i;
    MIGRAPHX_WARN(ito.offset < get_shape().element_space(),
                  i,
                  "Out of bounds access at offset: ",
                  ito.offset);
    return x[ito.offset];
}
constexpr T* data() const { return x; } constexpr T* data() const { return x; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment