Unverified commit ddbbe54b authored by Paul Fultz II, committed by GitHub

Refactor vectorization and preloading for pointwise fusions (#1184)

Improves performance for add_gelu: in BERT it is 4x faster, and mul_add is 50% faster than what we currently have.
parent f55d7c24
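The core of the change is visible in the pointwise kernel template further down: the host-side compiler now picks the fast axis, a vector size, and a per-argument preload mask, and bakes them into the generated source as template arguments. As a sketch only, for a hypothetical two-input fusion (every concrete value below is illustrative, not taken from the commit), the interpolated kernel would come out roughly as:

extern "C" {
// Hypothetical instantiation: vec_size = 4 on axis 1, with the second
// input (a broadcast bias) preloaded into LDS; the leading `false` in
// the preload mask is the output, which is never preloaded.
__global__ void kernel(void* private_p0, void* private_p1, void* private_p2)
{
    auto idx = make_index();
    pointwise(idx, auto_preload<false, false, true>(idx), vectorize<4, 1>())(
        [](auto x, auto y) { return x + y; }, private_p0, private_p1, private_p2);
}
}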
......@@ -22,6 +22,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#if MIGRAPHX_USE_HIPRTC
......@@ -247,6 +248,16 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
MIGRAPHX_THROW("Missing hsaco");
};
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{
for(const auto& src : srcs)
{
if(src.path.extension() != ".cpp")
continue;
std::cout << std::string(src.content.first, src.len()) << std::endl;
}
}
if(enabled(MIGRAPHX_GPU_DUMP_ASM{}))
{
......
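Setting the new MIGRAPHX_GPU_DUMP_SRC environment variable prints each generated .cpp kernel source to stdout before compilation, which makes it easy to inspect the vector size and preload mask the compiler actually chose for a given fusion.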
......@@ -6,6 +6,7 @@
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
......@@ -28,7 +29,8 @@ ${preamble}
extern "C" {
__global__ void kernel(${params})
{
-    pointwise(${lambda}, ${args});
+    auto idx = make_index();
+    pointwise(idx, auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>())(${lambda}, ${args});
}
}
......@@ -41,40 +43,105 @@ struct pointwise_compiler : compiler<pointwise_compiler>
{
std::vector<std::string> names() const { return {"pointwise"}; }
-    static std::size_t oversubscribe(const std::vector<shape>& inputs)
+    static std::size_t oversubscribe_if(bool b)
     {
-        if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); }))
-            return 1;
-        else
+        if(b)
             return 256;
+        else
+            return 1;
     }
static std::size_t find_fast_axis(const std::vector<shape>& inputs)
{
auto permutation = find_permutation(inputs);
auto it = std::max_element(permutation.begin(), permutation.end());
return it - permutation.begin();
}
-    static std::size_t vectorize_elements(const std::vector<shape>& inputs)
+    static std::vector<bool> preload(std::size_t axis, const std::vector<shape>& inputs)
     {
-        std::size_t n = inputs.front().elements();
+        const std::size_t max_lds_bytes = 4096;
std::vector<bool> result;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(result),
[&](const shape& input) { return input.strides()[axis] == 0; });
auto bytes = std::inner_product(inputs.begin(),
inputs.end(),
result.begin(),
std::size_t{0},
std::plus<>{},
[](const shape& s, bool b) -> std::size_t {
if(b)
return s.bytes();
return 0;
});
if(bytes < max_lds_bytes)
return result;
// TODO: Try to partially preload items
std::fill(result.begin(), result.end(), false);
return result;
}
static std::string preload_str(const std::vector<bool>& bs)
{
std::vector<std::string> bool_strs;
std::transform(bs.begin(), std::prev(bs.end()), std::back_inserter(bool_strs), [](bool b) {
if(b)
return "true";
return "false";
});
return "false, " + join_strings(bool_strs, ", ");
}
static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
{
+        // If all inputs are half then only use half2
         if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
-               return s.packed() or s.broadcasted();
+               return s.type() == shape::half_type;
            }))
-        {
-            if((n % 4) == 0)
-                return n / 4;
-            else if((n % 2) == 0)
-                return n / 2;
-        }
-        return n;
+            return {2};
+        return {4, 2};
}
static auto vectorize_elements(std::size_t axis, const std::vector<shape>& inputs)
{
auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(max_vec_size),
[&](const auto& input) -> std::size_t {
auto stride = input.strides()[axis];
auto len = input.lens()[axis];
if(stride != 0 and stride != 1)
return 1;
auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
if(it != sizes.end())
return *it;
return 1;
});
return *std::min_element(max_vec_size.begin(), max_vec_size.end());
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
-        options.set_launch_params(
-            v, compute_global_for(ctx, vectorize_elements(inputs), oversubscribe(inputs)));
         options.inputs         = inputs;
         options.output         = inputs.back();
         options.virtual_inputs = reduce_dims(inputs);
         options.params         = "-Wno-float-equal";
-        auto src = interpolate_string(pointwise_kernel,
+        auto axis     = find_fast_axis(options.virtual_inputs);
+        auto vec_size = vectorize_elements(axis, options.virtual_inputs);
+        auto preloads = preload(axis, options.virtual_inputs);
+        auto is_preloading =
+            std::accumulate(preloads.begin(), preloads.end(), false, std::logical_or<>{});
+        options.set_launch_params(v,
+                                  compute_global_for(ctx,
+                                                     options.output.elements() / vec_size,
+                                                     oversubscribe_if(not is_preloading)));
+        auto src = interpolate_string(pointwise_kernel,
                                       {{"params", enum_params(inputs.size(), "void * private_p")},
                                        {"args", enum_params(inputs.size(), "private_p")},
                                        {"lambda", v.at("lambda").to<std::string>()},
+                                       {"vec_size", std::to_string(vec_size)},
+                                       {"axis", std::to_string(axis)},
+                                       {"preloads", preload_str(preloads)},
                                        {"preamble", v.get("preamble", std::string{})}});
return compile_hip_code_object(src, options);
}
......
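To make the host-side selection concrete, here is a standalone sketch (ordinary C++, not the MIGraphX API) of the per-input decision vectorize_elements makes above; the kernel-wide vector size is the minimum over all inputs:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Per-input vector size: the fast axis must be broadcast (stride 0) or
// contiguous (stride 1), and the candidate size must divide the axis length.
std::size_t max_vec_for(std::size_t stride, std::size_t len,
                        const std::vector<std::size_t>& sizes)
{
    if(stride != 0 and stride != 1)
        return 1;
    auto it = std::find_if(
        sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
    return it == sizes.end() ? 1 : *it;
}

int main()
{
    std::vector<std::size_t> sizes{4, 2}; // non-half fusions try 4, then 2
    // A contiguous axis of length 64 vectorizes by 4; a stride-2 input
    // forces the whole kernel back to scalar, since the minimum wins.
    std::printf("%zu %zu\n", max_vec_for(1, 64, sizes), max_vec_for(2, 64, sizes));
}

The preload mask follows the same shape inspection: an input with stride 0 on the fast axis is read repeatedly by many threads, so it is staged into LDS, but only while the flagged inputs together fit in the 4096-byte budget; otherwise preloading is disabled wholesale and the launch oversubscribes threads instead.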
......@@ -3,6 +3,14 @@
#include <migraphx/kernels/array.hpp>
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
namespace migraphx {
struct swallow
......@@ -161,6 +169,18 @@ constexpr auto pack(Ts... xs)
return [=](auto f) { return f(xs...); };
}
template <class G, class F>
constexpr auto join(G g, F f)
{
return f([=](auto... xs) { return g(xs...); });
}
template <class G, class F, class... Fs>
constexpr auto join(G g, F f, Fs... fs)
{
return f([=](auto... xs) { return join([=](auto... ys) { return g(xs..., ys...); }, fs...); });
}
template <class Compare, class P1, class P2>
constexpr auto pack_compare(Compare compare, P1 p1, P2 p2)
{
......@@ -191,39 +211,45 @@ constexpr auto arg(IntegralConstant ic)
return arg_c<ic>();
}
-inline constexpr auto rotate_last()
+template <class F>
+constexpr auto make_transform(F f)
 {
-    return [](auto... xs) {
-        return [=](auto&& f) {
-            return sequence_c<sizeof...(xs)>([&](auto... is) {
-                constexpr auto size = sizeof...(is);
-                return f(arg_c<(is + size - 1) % size>()(xs...)...);
-            });
-        };
-    };
+    return [=](auto... xs) { return [=](auto g) { return f(g, xs...); }; };
}
// An arg transformation takes the arguments and then a function to call with the new arguments:
// transform(xs...)([](auto... ys) { ... })
// transform_args takes a list of transformations and applies them one after another
 template <class F>
 constexpr auto transform_args(F f)
 {
-    return [=](auto... xs) {
-        return [=](auto g) { return f(xs...)([&](auto... ys) { return g(ys...); }); };
-    };
+    return f;
 }

 template <class F, class... Fs>
 constexpr auto transform_args(F f, Fs... fs)
 {
-    return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); };
+    return make_transform([=](auto g, auto... xs) {
+        return f(xs...)([=](auto... ys) { return transform_args(fs...)(ys...)(g); });
+    });
 }

-// NOLINTNEXTLINE
-#define MIGRAPHX_RETURNS(...) \
-    ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
+// identity transform
+inline constexpr auto transform_args()
+{
+    return make_transform([](auto f, auto... xs) { return f(xs...); });
+}

-// NOLINTNEXTLINE
-#define MIGRAPHX_LIFT(...) \
-    [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
+// Rotate the last argument to be the first argument
+inline constexpr auto rotate_last()
+{
+    return make_transform([](auto f, auto... xs) {
+        return sequence_c<sizeof...(xs)>([&](auto... is) {
+            constexpr auto size = sizeof...(is);
+            return f(arg_c<(is + size - 1) % size>()(xs...)...);
+        });
+    });
+}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
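The join helper added above concatenates the argument packs produced by several continuation-passing values before invoking a single consumer; this is what later lets auto_preload hand all preloaded tensors to one callback. A minimal host-side demo of the same mechanics, as ordinary standalone C++:

#include <cstdio>

template <class G, class F>
auto join(G g, F f)
{
    // f is a "packed value": it takes a consumer and calls it with its args.
    return f([=](auto... xs) { return g(xs...); });
}

template <class G, class F, class... Fs>
auto join(G g, F f, Fs... fs)
{
    // Peel one pack off, capture its arguments, and recurse on the rest.
    return f([=](auto... xs) {
        return join([=](auto... ys) { return g(xs..., ys...); }, fs...);
    });
}

int main()
{
    auto one = [](auto g) { return g(1); };
    auto two = [](auto g) { return g(2, 3); };
    // Concatenates both packs and prints "1 2 3".
    join([](int a, int b, int c) { return std::printf("%d %d %d\n", a, b, c); },
         one, two);
}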
......@@ -38,20 +38,17 @@ constexpr implicit_conversion_op<T> implicit_conversion(T x)
template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{
-    preload<typename T::type>(idx, xs...)([&](auto... ps) {
-        idx.global_stride(out.get_shape().elements(),
-                          [&](auto i) { out[i] = implicit_conversion(f(ps[i]...)); });
-    });
+    idx.global_stride(out.get_shape().elements(),
+                      [&](auto i) { out[i] = implicit_conversion(f(xs[i]...)); });
}
-template <class F, class... Ts>
-__device__ void pointwise(F f, Ts*... ps)
+template <class... Transforms>
+__device__ auto pointwise(index idx, Transforms... transforms)
 {
-    auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize());
-    t(ps...)([&](auto... xs) {
-        auto idx = make_index();
-        pointwise_tensor(idx, f, xs...);
-    });
+    return [=](auto f, auto*... ps) {
+        auto t = transform_args(make_tensors(), rotate_last(), transforms...);
+        t(ps...)([&](auto... xs) { pointwise_tensor(idx, f, xs...); });
+    };
}
} // namespace migraphx
......
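With this split, pointwise is a higher-order function: the compiler-generated source decides which transforms to compose, while the kernel body stays generic. As a hedged sketch (p0, p1, p2 are assumed device pointers, not names from the commit), a fusion needing neither preloading nor vectorization would pass no extra transforms at all:

// Only make_tensors() and rotate_last() apply, so the lambda sees plain
// tensor_views with the output tensor rotated to the front.
auto idx = make_index();
pointwise(idx)([](auto x, auto y) { return x * y; }, p0, p1, p2);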
......@@ -3,6 +3,8 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp>
namespace migraphx {
......@@ -73,7 +75,7 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
{
if constexpr(decltype(tensor_vec_size(x)){} == 0)
{
-            auto v = vectorize(x);
+            auto v = auto_vectorize(x);
auto b = as_vec(tensor_vec_size(v), buffer + offset);
idx.local_stride(v.get_shape().element_space(),
[&](auto i) { b[i] = v.data()[i]; });
......@@ -126,5 +128,47 @@ __device__ auto preload(index idx, Ts... xs)
};
}
inline __device__ auto auto_preload(index idx)
{
return make_transform([=](auto f, auto out, auto... xs) {
preload<typename decltype(out)::type>(idx, xs...)([&](auto... ys) { f(out, ys...); });
});
}
template <bool B, class T>
__device__ auto preload_copy(index idx, T x)
{
return [=](auto f) {
if constexpr(B)
{
using type = typename T::type;
constexpr auto size = get_shape_c<T>{}.element_space();
__shared__ type buffer[size];
// TODO: Always vectorize when size > 4, and then use a second loop for the remainder
constexpr auto n = find_vectorize_size([&](auto i) { return (size % i) == 0; });
auto input = as_vec<n>(remove_bool(x.data()));
auto b = as_vec<n>(remove_bool(buffer));
idx.local_stride(size / n, [&](auto i) { b[i] = input[i]; });
return f(x.with(buffer));
}
else
{
return f(x);
}
};
}
template <bool... Bs>
__device__ auto auto_preload(index idx)
{
return make_transform([=](auto f, auto... xs) {
auto invoke = [=](auto... ys) {
__syncthreads();
f(ys...);
};
join(invoke, preload_copy<Bs>(idx, xs)...);
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP
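Reading the two pieces together: for Bs = {false, true} applied to (out, x), preload_copy<false> passes out through untouched, preload_copy<true> copies x into a __shared__ buffer with a vectorized local_stride loop and hands back x.with(buffer), and join gathers both results so that invoke can run __syncthreads() exactly once before calling f with the LDS-backed arguments.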
......@@ -60,10 +60,19 @@ constexpr auto common_vec_size()
})(vec_size<Ts>()...);
}
+// Bools cannot be used as a vector type, so convert them to uint8
+template <class T>
+__device__ __host__ T* remove_bool(T* x)
+{
+    return x;
+}
+inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
template <index_int N, class T>
__device__ __host__ auto as_vec(T* x)
{
-    if constexpr(N == 0)
+    if constexpr(N < 2)
return x;
else
return reinterpret_cast<vec<T, N>*>(x);
......
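The remove_bool shim exists because a vector of bool is ill-formed; assuming vec<T, N> is a clang extension vector type (as the MIGraphX kernel headers define it elsewhere), the intended use is a fragment like:

// Assumed context: as_vec/remove_bool as defined above, inside device code.
bool flags[64];
auto* v = as_vec<4>(remove_bool(flags)); // vec<uint8_t, 4>* covering flags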
......@@ -50,19 +50,10 @@ constexpr auto shape_step(Shape s, Axis)
});
}
-// Bools can not be used as a vector type so convert it to uint8
-template <class T>
-__device__ __host__ T* remove_bool(T* x)
-{
-    return x;
-}
-inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
template <index_int N, class T, class Axis>
__device__ __host__ auto as_vec(T x, Axis axis)
{
-    if constexpr(N == 0)
+    if constexpr(N < 2)
return x;
else
return make_tensor_view(as_vec<N>(remove_bool(x.data())),
......@@ -72,7 +63,7 @@ __device__ __host__ auto as_vec(T x, Axis axis)
template <index_int N, class T, class Axis>
constexpr auto tensor_step(T x, Axis axis)
{
-    if constexpr(N == 0)
+    if constexpr(N < 2)
{
return x;
}
......@@ -157,11 +148,11 @@ constexpr auto find_vectorize_size(P pred)
else if constexpr(decltype(pred(_c<2>)){})
return _c<2>;
else
-        return _c<0>;
+        return _c<1>;
}
template <class T>
-__host__ __device__ auto vectorize(T x)
+__host__ __device__ auto auto_vectorize(T x)
{
if constexpr(tensor_vec_size<T>() == 0)
{
......@@ -194,7 +185,7 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
{
MIGRAPHX_ASSERT(s.strides[axis] == 0 or s.strides[axis] == 1);
MIGRAPHX_ASSERT(s.lens[axis] > 0);
-        MIGRAPHX_ASSERT(n == 0 or s.lens[axis] % n == 0);
+        MIGRAPHX_ASSERT(n == 1 or s.lens[axis] % n == 0);
if constexpr(s.strides[axis] == 0)
return tensor_step<n>(x, axis);
else
......@@ -215,7 +206,32 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
inline __device__ __host__ auto auto_vectorize()
{
-    return [](auto... xs) { return [=](auto f) { auto_vectorize_impl(f, xs...); }; };
+    return make_transform([](auto f, auto... xs) { auto_vectorize_impl(f, xs...); });
}
template <index_int N, index_int Axis, class T>
__device__ __host__ auto vectorize_tensor(T x)
{
constexpr auto shape = get_shape_c<T>{};
if constexpr(shape.strides[Axis] == 0)
return tensor_step<N>(x, _c<Axis>);
else
return as_vec<N>(x, _c<Axis>);
}
template <index_int N, index_int Axis>
__device__ __host__ auto vectorize()
{
return make_transform([](auto f, auto... xs) {
if constexpr(N < 2)
{
f(xs...);
}
else
{
f(vectorize_tensor<N, Axis>(xs)...);
}
});
}
} // namespace migraphx
......
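The new vectorize<N, Axis> transform mirrors the host-side rules exactly: a contiguous tensor (stride 1 on Axis) is reinterpreted through as_vec<N>, so each loop iteration moves a whole vec<T, N>, while a broadcast tensor (stride 0) goes through tensor_step<N>, which steps the shape (the axis length shrinks by a factor of N) without reinterpreting the data, letting the same scalar feed all N lanes. With N = 4 over an axis of length 64, both views expose 16 elements and the pointwise loop runs a quarter of the iterations.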