Unverified Commit f9a5b81e authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Reduce with runtime compilation (#1150)

There is significant improvement on larger tensors with half almost 50% faster:

lens: [1024, 384, 768]
gpu::code_object[code_object=13832,symbol_name=kernel,global=39321600,local=256,]: 1.16685ms
gpu::reduce_sum[axes={2}]: 1.73126ms
Also for non-trivial layouts this can sometimes be over 2x faster:

lens: [64, 1024, 768, 4]
gpu::code_object[code_object=13832,symbol_name=kernel,global=39321600,local=256,]: 1.1706ms
gpu::reduce_sum[axes={1}]: 2.63375ms
Of course if the stride becomes larger this speed improvement diminishes due to poor memory access patterns. A lane_reduce instead of a block_reduce is needed for such type of kernels. I plan to address that in a future PR.

Finally, this also includes a MIGRAPHX_GPU_DUMP_ASM env variable which will print out the assembly when the kernel compiles.
parent 12007dba
......@@ -6,6 +6,15 @@
namespace migraphx {
template <class T, class U = T&&>
U private_declval(int);
template <class T>
T private_declval(long);
template <class T>
auto declval() noexcept -> decltype(private_declval<T>(0));
template <class T>
struct type_identity
{
......@@ -26,20 +35,73 @@ struct enable_if<true, T>
template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
template <class From, class To>
struct is_convertible : bool_constant<__is_convertible(From, To)>
{
};
template <class T, class U>
struct is_same : false_type
{
};
template <class T>
struct is_same<T, T> : true_type
{
};
// NOLINTNEXTLINE
#define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
struct name : bool_constant<__##name(T)> \
{ \
}
// NOLINTNEXTLINE
#define MIGRAPHX_BUILTIN_TYPE_TRAIT2(name) \
template <class T, class U> \
struct name : bool_constant<__##name(T, U)> \
{ \
}
// NOLINTNEXTLINE
#define MIGRAPHX_BUILTIN_TYPE_TRAITN(name) \
template <class... Ts> \
struct name : bool_constant<__##name(Ts...)> \
{ \
}
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_arithmetic);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_destructible);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_nothrow_destructible);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_pointer);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_scalar);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_signed);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_void);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_abstract);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_aggregate);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_array);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_class);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_compound);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_const);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_empty);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_enum);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_final);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_floating_point);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_function);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_fundamental);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_integral);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_literal_type);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_lvalue_reference);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_function_pointer);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_object_pointer);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_pointer);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_object);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_pod);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_polymorphic);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_reference);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_rvalue_reference);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_standard_layout);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivial);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivially_destructible);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_union);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_unsigned);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_volatile);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_assignable);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_base_of);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_convertible);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_nothrow_assignable);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_same);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_trivially_assignable);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible);
template <class T>
struct remove_reference
......@@ -68,6 +130,68 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
template <class T>
using add_pointer_t = typename add_pointer<T>::type;
template <class... Ts>
struct common_type;
template <class T>
struct common_type<T>
{
using type = T;
};
template <class T, class U>
struct common_type<T, U>
{
using type = decltype(true ? declval<T>() : declval<U>());
};
template <class T, class U, class... Us>
struct common_type<T, U, Us...>
{
using type = typename common_type<typename common_type<T, U>::type, Us...>::type;
};
template <class... Ts>
using common_type_t = typename common_type<Ts...>::type;
constexpr unsigned long int_max(unsigned long n) { return (1u << (n * 8)) - 1; }
template <class T>
constexpr T numeric_max()
{
if constexpr(is_integral<T>{})
{
if constexpr(is_unsigned<T>{})
return int_max(sizeof(T)) * 2;
else
return int_max(sizeof(T));
}
else if constexpr(is_same<T, double>{})
return __DBL_MAX__;
else if constexpr(is_same<T, float>{})
return __FLT_MAX__;
else if constexpr(is_same<T, migraphx::half>{})
return __FLT16_MAX__;
else
return 0;
}
template <class T>
constexpr T numeric_lowest()
{
if constexpr(is_integral<T>{})
{
if constexpr(is_unsigned<T>{})
return 0;
else
return -numeric_max<T>() - 1;
}
else
{
return -numeric_max<T>();
}
}
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
} // namespace migraphx
......
......@@ -181,11 +181,6 @@ struct miopen_apply
add_extend_op("pad");
add_extend_op("pooling");
add_extend_op("prefix_scan_sum");
add_extend_op("reduce_max");
add_extend_op("reduce_mean");
add_extend_op("reduce_min");
add_extend_op("reduce_prod");
add_extend_op("reduce_sum");
add_extend_op("reverse");
add_extend_op("rnn_var_sl_last_output");
add_extend_op("rnn_var_sl_shift_output");
......
......@@ -17,7 +17,7 @@ struct test_reduce_op_large : verify_program<test_reduce_op_large<Op, Axis, T>>
auto* mm = p.get_main_module();
migraphx::shape s{T, {3, 1026, 4, 3}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(Op{{1}}, x);
mm->add_instruction(Op{{Axis}}, x);
return p;
};
};
......
......@@ -15,13 +15,14 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{T, {3, 4, 8, 8}};
migraphx::shape s{T, {3, 4, 2, 2}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(Op{{1}}, x);
mm->add_instruction(Op{{Axis}}, x);
return p;
};
};
template struct test_reduce_op_small<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::int32_type>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment