Unverified Commit 78a3c9b7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add auto-vectorization of pointwise operators (#1047)

* Enable auto vectorization
* Handle vector types with convert function
* Dont vectorize when it will cause problems with preload
parent b7218806
#include <migraphx/cpp_generator.hpp> #include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp> #include <migraphx/builtin.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -51,6 +52,7 @@ cpp_generator::function& cpp_generator::function::set_types(const module& m) ...@@ -51,6 +52,7 @@ cpp_generator::function& cpp_generator::function::set_types(const module& m)
cpp_generator::function& cpp_generator::function&
cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse) cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse)
{ {
this->params.clear();
auto pmap = m.get_parameter_shapes(); auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end()); std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform( std::transform(
...@@ -63,11 +65,30 @@ cpp_generator::function::set_types(const module& m, const std::function<std::str ...@@ -63,11 +65,30 @@ cpp_generator::function::set_types(const module& m, const std::function<std::str
return *this; return *this;
} }
cpp_generator::function& cpp_generator::function::set_generic_types(const module& m)
{
this->params.clear();
auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
return param{p.first, "T" + p.first};
});
std::transform(input_map.begin(),
input_map.end(),
std::back_inserter(this->tparams),
[&](auto&& p) { return "class T" + p.first; });
this->return_type = "auto";
return *this;
}
struct cpp_generator_impl struct cpp_generator_impl
{ {
std::stringstream fs{}; std::stringstream fs{};
std::size_t function_count = 0; std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr; std::function<std::string(std::string)> fmap = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {};
}; };
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {} cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
...@@ -83,41 +104,54 @@ cpp_generator::~cpp_generator() noexcept = default; ...@@ -83,41 +104,54 @@ cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; } void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{
impl->point_op_map[op_name] = code;
}
std::string cpp_generator::generate_point_op(const operation& op, std::string cpp_generator::generate_point_op(const operation& op,
const std::vector<std::string>& args) const std::vector<std::string>& args)
{ {
auto v = op.to_value(); auto v = op.to_value();
auto attributes = op.attributes(); std::string code;
if(not attributes.contains("point_op")) if(contains(impl->point_op_map, op.name()))
MIGRAPHX_THROW("op is missing point_op attribute: " + op.name()); {
return interpolate_string(attributes["point_op"].to<std::string>(), code = impl->point_op_map.at(op.name());
[&](auto start, auto last) -> std::string { }
auto key = trim({start, last}); else
if(key.empty()) {
MIGRAPHX_THROW("Empty parameter"); auto attributes = op.attributes();
std::string fselector = "function:"; if(not attributes.contains("point_op"))
if(starts_with(key, fselector)) MIGRAPHX_THROW("op is missing point_op attribute: " + op.name());
{ code = attributes["point_op"].to<std::string>();
auto fname = key.substr(fselector.size()); }
if(impl->fmap == nullptr) return interpolate_string(code, [&](auto start, auto last) -> std::string {
return fname; auto key = trim({start, last});
else if(key.empty())
return impl->fmap(fname); MIGRAPHX_THROW("Empty parameter");
} std::string fselector = "function:";
else if(with_char(::isdigit)(key[0])) if(starts_with(key, fselector))
{ {
auto i = std::stoul(key); auto fname = key.substr(fselector.size());
return args.at(i); if(impl->fmap == nullptr)
} return fname;
else if(v.contains(key)) else
{ return impl->fmap(fname);
return v[key].template to<std::string>(); }
} else if(with_char(::isdigit)(key[0]))
else {
{ auto i = std::stoul(key);
return key; return args.at(i);
} }
}); else if(v.contains(key))
{
return v[key].template to<std::string>();
}
else
{
return key;
}
});
} }
std::string cpp_generator::str() const { return impl->fs.str(); } std::string cpp_generator::str() const { return impl->fs.str(); }
...@@ -148,6 +182,8 @@ cpp_generator::function cpp_generator::generate_module(const module& m) ...@@ -148,6 +182,8 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
std::string cpp_generator::create_function(const cpp_generator::function& f) std::string cpp_generator::create_function(const cpp_generator::function& f)
{ {
impl->function_count++; impl->function_count++;
if(not f.tparams.empty())
impl->fs << "template<" << join_strings(f.tparams, ", ") << ">\n";
std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name; std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name; impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
char delim = '('; char delim = '(';
......
...@@ -34,6 +34,7 @@ struct cpp_generator ...@@ -34,6 +34,7 @@ struct cpp_generator
std::string return_type = "void"; std::string return_type = "void";
std::string name = ""; std::string name = "";
std::vector<std::string> attributes = {}; std::vector<std::string> attributes = {};
std::vector<std::string> tparams = {};
function& set_body(const module& m, const generate_module_callback& g); function& set_body(const module& m, const generate_module_callback& g);
function& set_body(const std::string& s) function& set_body(const std::string& s)
{ {
...@@ -52,6 +53,7 @@ struct cpp_generator ...@@ -52,6 +53,7 @@ struct cpp_generator
} }
function& set_types(const module& m); function& set_types(const module& m);
function& set_types(const module& m, const std::function<std::string(shape)>& parse); function& set_types(const module& m, const std::function<std::string(shape)>& parse);
function& set_generic_types(const module& m);
}; };
cpp_generator(); cpp_generator();
...@@ -66,6 +68,8 @@ struct cpp_generator ...@@ -66,6 +68,8 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f); void fmap(const std::function<std::string(std::string)>& f);
void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args); std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
std::string str() const; std::string str() const;
......
...@@ -114,6 +114,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option ...@@ -114,6 +114,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global); options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local); options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " " + join_strings(compiler_warnings(), " "); options.params += " " + join_strings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror"; options.params += " -Werror";
auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name()); auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
if(cos.size() != 1) if(cos.size() != 1)
......
...@@ -63,8 +63,16 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu ...@@ -63,8 +63,16 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}}); run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g; cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; }); g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
auto name = g.create_function(g.generate_module(m).set_attributes({"__device__"})); g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
return compile_pointwise((ctx), inputs, "&" + name, g.str()); g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
auto name =
g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
} }
} // namespace gpu } // namespace gpu
......
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
namespace migraphx { namespace migraphx {
#define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__
#define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__)
// Workaround hip's broken abort on device code // Workaround hip's broken abort on device code
#ifdef __HIP_DEVICE_COMPILE__ #ifdef __HIP_DEVICE_COMPILE__
// NOLINTNEXTLINE // NOLINTNEXTLINE
...@@ -14,19 +17,67 @@ namespace migraphx { ...@@ -14,19 +17,67 @@ namespace migraphx {
#define MIGRAPHX_HIP_NORETURN [[noreturn]] #define MIGRAPHX_HIP_NORETURN [[noreturn]]
#endif #endif
namespace debug {
struct swallow
{
template <class... Ts>
constexpr swallow(Ts&&...)
{
}
};
template <size_t N>
struct print_buffer
{
char buffer[N + 1] = {0};
char* pos = buffer;
constexpr void append(char c)
{
if(c == 0)
return;
if(pos < buffer + N)
{
*pos = c;
pos++;
}
}
template <size_t M>
constexpr void append(const char (&array)[M])
{
for(int i = 0; i < M; i++)
append(array[i]);
}
};
template <class... Ts>
__host__ __device__ void print(const Ts&... xs)
{
const auto size = (sizeof(xs) + ...);
print_buffer<size> buffer;
swallow{(buffer.append(xs), 0)...};
printf("%s", buffer.buffer);
}
} // namespace debug
// noreturn cannot be used on this function because abort in hip is broken // noreturn cannot be used on this function because abort in hip is broken
template <class T1, class T2, class T3, class T4>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
assert_fail(const char* assertion, const char* file, unsigned int line, const char* function) assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& function)
{ {
printf("%s:%u: %s: assertion '%s' failed.\n", file, line, function, assertion); // printf is broken on hip with more than one argument, so use a simple print functions instead
debug::print(file, ":", line, ": ", function, ": assertion '", assertion, "' failed.\n");
// printf("%s:%s: %s: assertion '%s' failed.\n", file, line, function, assertion);
abort(); abort();
} }
#ifdef MIGRAPHX_DEBUG #ifdef MIGRAPHX_DEBUG
#define MIGRAPHX_ASSERT(cond) \ #define MIGRAPHX_ASSERT(cond) \
((cond) ? void(0) : [](auto... xs) { \ ((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
assert_fail(xs...); \ assert_fail(private_migraphx_xs...); \
}(#cond, __FILE__, __LINE__, __PRETTY_FUNCTION__)) }(#cond, __FILE__, MIGRAPHX_STRINGIZE(__LINE__), __PRETTY_FUNCTION__))
#else #else
#define MIGRAPHX_ASSERT(cond) #define MIGRAPHX_ASSERT(cond)
#endif #endif
......
...@@ -137,12 +137,48 @@ constexpr void each_args(F) ...@@ -137,12 +137,48 @@ constexpr void each_args(F)
{ {
} }
template <class F, class T>
constexpr auto fold_impl(F&&, T&& x)
{
return static_cast<T&&>(x);
}
template <class F, class T, class U, class... Ts>
constexpr auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs)
{
return fold_impl(f, f(static_cast<T&&>(x), static_cast<U&&>(y)), static_cast<Ts&&>(xs)...);
}
template <class F>
constexpr auto fold(F f)
{
return [=](auto&&... xs) { return fold_impl(f, static_cast<decltype(xs)&&>(xs)...); };
}
template <class... Ts> template <class... Ts>
auto pack(Ts... xs) constexpr auto pack(Ts... xs)
{ {
return [=](auto f) { return f(xs...); }; return [=](auto f) { return f(xs...); };
} }
template <class Compare, class P1, class P2>
constexpr auto pack_compare(Compare compare, P1 p1, P2 p2)
{
return p1([&](auto... xs) {
return p2([&](auto... ys) {
auto c = [&](auto x, auto y) -> int {
if(compare(x, y))
return 1;
else if(compare(y, x))
return -1;
else
return 0;
};
return fold([](auto x, auto y) { return x ? x : y; })(c(xs, ys)..., 0);
});
});
}
template <index_int N> template <index_int N>
constexpr auto arg_c() constexpr auto arg_c()
{ {
...@@ -187,7 +223,7 @@ constexpr auto transform_args(F f, Fs... fs) ...@@ -187,7 +223,7 @@ constexpr auto transform_args(F f, Fs... fs)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \ #define MIGRAPHX_LIFT(...) \
([](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...)) [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP #endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
...@@ -13,6 +13,7 @@ struct integral_constant ...@@ -13,6 +13,7 @@ struct integral_constant
using type = integral_constant; using type = integral_constant;
constexpr operator value_type() const noexcept { return value; } constexpr operator value_type() const noexcept { return value; }
constexpr value_type operator()() const noexcept { return value; } constexpr value_type operator()() const noexcept { return value; }
static constexpr type to() { return {}; }
}; };
// NOLINTNEXTLINE // NOLINTNEXTLINE
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/vec.hpp> #include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp> #include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/math_functions.h> #include <hip/math_functions.h>
...@@ -19,19 +20,30 @@ constexpr T as_float(T x) ...@@ -19,19 +20,30 @@ constexpr T as_float(T x)
} // namespace math } // namespace math
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH(name, fname) \ #define MIGRAPHX_DEVICE_MATH(name, fname) \
template <class... Ts> \ template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(fname(xs...)) auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(fname(xs...))
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \ #define MIGRAPHX_DEVICE_MATH_VEC(name) \
template <class... Ts> \ template <class... Ts, MIGRAPHX_REQUIRES(is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs) MIGRAPHX_RETURNS(fname(x, xs...)) auto __device__ name(Ts... xs) \
{ \
return vec_transform(xs...)([](auto... ys) { return name(ys...); }); \
}
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \ #define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts> \ template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::half x, Ts... xs) \ auto __device__ name(type x, Ts... xs)->type \
{ \
return fname(x, xs...); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...)) MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
MIGRAPHX_DEVICE_MATH(abs, ::abs) MIGRAPHX_DEVICE_MATH(abs, ::abs)
...@@ -99,21 +111,51 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan) ...@@ -99,21 +111,51 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh) MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
template <class T, class U> template <class T, class U>
constexpr auto& max(const T& a, const U& b) constexpr auto where(bool cond, const T& a, const U& b)
{ {
return (a < b) ? b : a; return cond ? a : b;
} }
MIGRAPHX_DEVICE_MATH_VEC(abs)
MIGRAPHX_DEVICE_MATH_VEC(acos)
MIGRAPHX_DEVICE_MATH_VEC(acosh)
MIGRAPHX_DEVICE_MATH_VEC(asin)
MIGRAPHX_DEVICE_MATH_VEC(asinh)
MIGRAPHX_DEVICE_MATH_VEC(atan)
MIGRAPHX_DEVICE_MATH_VEC(atanh)
MIGRAPHX_DEVICE_MATH_VEC(ceil)
MIGRAPHX_DEVICE_MATH_VEC(cos)
MIGRAPHX_DEVICE_MATH_VEC(cosh)
MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(round)
MIGRAPHX_DEVICE_MATH_VEC(rsqrt)
MIGRAPHX_DEVICE_MATH_VEC(sin)
MIGRAPHX_DEVICE_MATH_VEC(sinh)
MIGRAPHX_DEVICE_MATH_VEC(sqrt)
MIGRAPHX_DEVICE_MATH_VEC(tan)
MIGRAPHX_DEVICE_MATH_VEC(tanh)
MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U> template <class T, class U>
constexpr auto& min(const T& a, const U& b) constexpr auto max(const T& a, const U& b)
{ {
return (a > b) ? b : a; return where(a < b, b, a);
} }
template <class T, class U> template <class T, class U>
constexpr T convert(U x) constexpr auto min(const T& a, const U& b)
{ {
return x; return where(a > b, b, a);
}
template <class T, class U>
constexpr auto convert(U v)
{
return vec_transform(v)([](auto x) -> T { return x; });
} }
} // namespace migraphx } // namespace migraphx
......
...@@ -10,13 +10,38 @@ ...@@ -10,13 +10,38 @@
namespace migraphx { namespace migraphx {
template <class T>
struct implicit_conversion_op
{
T x;
template <index_int N, class U>
constexpr operator vec<U, N>() const
{
static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>);
}
template <class U>
constexpr operator U() const
{
return x;
}
};
template <class T>
constexpr implicit_conversion_op<T> implicit_conversion(T x)
{
return {x};
}
template <class F, class T, class... Ts> template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{ {
preload<typename T::type>(idx, xs...)([&](auto... ps) { preload<typename T::type>(idx, xs...)([&](auto... ps) {
idx.global_stride(out.get_shape().elements(), [&](auto i) { idx.global_stride(out.get_shape().elements(), [&](auto i) {
auto multi_idx = out.get_shape().multi(i); auto multi_idx = out.get_shape().multi(i);
out[multi_idx] = f(ps[multi_idx]...); out[multi_idx] = implicit_conversion(f(ps[multi_idx]...));
}); });
}); });
} }
...@@ -24,7 +49,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) ...@@ -24,7 +49,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
template <class F, class... Ts> template <class F, class... Ts>
__device__ void pointwise(F f, Ts*... ps) __device__ void pointwise(F f, Ts*... ps)
{ {
auto t = transform_args(make_tensors(), rotate_last()); auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize());
t(ps...)([&](auto... xs) { t(ps...)([&](auto... xs) {
auto idx = make_index(); auto idx = make_index();
pointwise_tensor(idx, f, xs...); pointwise_tensor(idx, f, xs...);
......
...@@ -29,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss) ...@@ -29,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss)
} }
template <class T, class... Shapes> template <class T, class... Shapes>
constexpr index_int compute_preload_size(Shapes...) constexpr index_int compute_preload_size_c(Shapes...)
{ {
index_int size = 0; index_int size = 0;
traverse_preload<T>(Shapes{}...)( traverse_preload<T>(Shapes{}...)(
...@@ -37,6 +37,12 @@ constexpr index_int compute_preload_size(Shapes...) ...@@ -37,6 +37,12 @@ constexpr index_int compute_preload_size(Shapes...)
return size; return size;
} }
template <class T, class... Shapes>
constexpr auto compute_preload_size(Shapes...)
{
return _c<compute_preload_size_c<T>(Shapes{}...)>;
}
template <class F, class T, class... Ts> template <class F, class T, class... Ts>
__device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs) __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
{ {
...@@ -48,11 +54,21 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs) ...@@ -48,11 +54,21 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
[&](auto x, auto offset, auto copy) { [&](auto x, auto offset, auto copy) {
if constexpr(copy) if constexpr(copy)
{ {
auto v = vectorize(x); if constexpr(decltype(tensor_vec_size(x)){} == 0)
auto b = as_vec(tensor_vec_size(v), buffer + offset); {
idx.local_stride(v.get_shape().element_space(), auto v = vectorize(x);
[&](auto i) { b[i] = v.data()[i]; }); auto b = as_vec(tensor_vec_size(v), buffer + offset);
return x.with(buffer + offset); idx.local_stride(v.get_shape().element_space(),
[&](auto i) { b[i] = v.data()[i]; });
return x.with(buffer + offset);
}
else
{
auto b = as_vec(tensor_vec_size(x), buffer + offset);
idx.local_stride(x.get_shape().element_space(),
[&](auto i) { b[i] = x.data()[i]; });
return x.with(b);
}
} }
else else
{ {
...@@ -78,7 +94,7 @@ template <class T, class... Ts> ...@@ -78,7 +94,7 @@ template <class T, class... Ts>
__device__ auto preload(index idx, Ts... xs) __device__ auto preload(index idx, Ts... xs)
{ {
using type = typename remove_vec<T>::type; using type = typename remove_vec<T>::type;
constexpr auto size = compute_preload_size<type>(xs.get_shape()...); constexpr auto size = decltype(compute_preload_size<type>(xs.get_shape()...)){};
const index_int max_size = 512 * sizeof(type); const index_int max_size = 512 * sizeof(type);
return [=](auto f) { return [=](auto f) {
if constexpr(size > 0 and size < max_size) if constexpr(size > 0 and size < max_size)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
namespace migraphx { namespace migraphx {
...@@ -24,6 +25,38 @@ constexpr auto vec_size() ...@@ -24,6 +25,38 @@ constexpr auto vec_size()
return decltype(vec_size(T{})){}; return decltype(vec_size(T{})){};
} }
template <class... Ts>
constexpr auto is_any_vec()
{
if constexpr(sizeof...(Ts) == 0)
return false_type{};
else
return bool_constant<((vec_size<Ts>() + ...) > 0)>{};
}
template <class T, class I>
constexpr auto vec_at(T x, I i)
{
if constexpr(vec_size<T>() == 0)
return x;
else
{
MIGRAPHX_ASSERT(i < vec_size<T>());
return x[i];
}
}
template <class... Ts>
constexpr auto common_vec_size()
{
return fold([](auto x, auto y) {
if constexpr(x > y)
return x;
else
return y;
})(vec_size<Ts>()...);
}
template <index_int N, class T> template <index_int N, class T>
__device__ __host__ auto as_vec(T* x) __device__ __host__ auto as_vec(T* x)
{ {
...@@ -33,5 +66,25 @@ __device__ __host__ auto as_vec(T* x) ...@@ -33,5 +66,25 @@ __device__ __host__ auto as_vec(T* x)
return reinterpret_cast<vec<T, N>*>(x); return reinterpret_cast<vec<T, N>*>(x);
} }
template <class... Ts>
constexpr auto vec_transform(Ts... xs)
{
return [=](auto f) {
if constexpr(is_any_vec<Ts...>())
{
using type = decltype(f(vec_at(xs, 0)...));
constexpr auto size = common_vec_size<Ts...>();
vec<type, size> result = {0};
for(int i = 0; i < size; i++)
result[i] = f(vec_at(xs, i)...);
return result;
}
else
{
return f(xs...);
}
};
}
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP #endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
...@@ -7,40 +7,70 @@ ...@@ -7,40 +7,70 @@
namespace migraphx { namespace migraphx {
template <class T> template <class T>
constexpr auto tensor_vec_size(T) constexpr auto tensor_vec_size()
{ {
return vec_size<typename T::type>(); return vec_size<typename T::type>();
} }
template <index_int N, class Shape> template <class T>
constexpr auto as_vec_shape(Shape s) constexpr auto tensor_vec_size(T)
{ {
auto lens = transform(s.lens, s.strides, [](auto len, auto stride) { return tensor_vec_size<T>();
if(stride == 1) }
return len / N;
else template <index_int N, class Shape, class Axis>
return len; constexpr auto shape_step(Shape s, Axis)
}); {
auto strides = transform(s.strides, [](auto stride) { static_assert(N > 0, "Vector size must be non-zero");
if(stride == 1) return sequence(s.lens.size(), [&](auto... is) {
return stride; auto lens = transform(s.lens, index_ints<is...>{}, [&](auto i, auto j) {
return stride / N; constexpr auto axis = Axis::to();
MIGRAPHX_ASSERT(i != 0);
MIGRAPHX_ASSERT(j != axis or i % N == 0);
if(j == axis)
return i / N;
else
return i;
});
auto strides = transform(s.strides, index_ints<is...>{}, [&](auto i, auto j) {
constexpr auto axis = Axis::to();
// If stride of the axis is zero then we dont need to adjust the other strides
if(Shape{}.strides[axis] == 0)
return i;
MIGRAPHX_ASSERT(j == axis or i % N == 0);
if(j == axis)
return i;
else
return i / N;
});
MIGRAPHX_ASSERT(make_shape(lens, strides).elements() * N == s.elements());
MIGRAPHX_ASSERT(strides[Axis{}] == 0 or
make_shape(lens, strides).element_space() * N == s.element_space());
return make_shape(lens, strides);
}); });
MIGRAPHX_ASSERT(make_shape(lens, strides).element_space() * N == s.element_space());
return make_shape(lens, strides);
} }
template <index_int N, class T> // Bools can not be used as a vector type so convert it to int8
__device__ __host__ auto as_vec(T x) template <class T>
__device__ __host__ T* remove_bool(T* x)
{
return x;
}
inline __device__ __host__ int8_t* remove_bool(bool* x) { return reinterpret_cast<int8_t*>(x); }
template <index_int N, class T, class Axis>
__device__ __host__ auto as_vec(T x, Axis axis)
{ {
if constexpr(N == 0) if constexpr(N == 0)
return x; return x;
else else
return make_tensor_view(as_vec<N>(x.data()), as_vec_shape<N>(x.get_shape())); return make_tensor_view(as_vec<N>(remove_bool(x.data())),
shape_step<N>(x.get_shape(), axis));
} }
template <index_int N, class T, class Axis> template <index_int N, class T, class Axis>
constexpr auto tensor_step(T x, Axis) constexpr auto tensor_step(T x, Axis axis)
{ {
if constexpr(N == 0) if constexpr(N == 0)
{ {
...@@ -49,17 +79,8 @@ constexpr auto tensor_step(T x, Axis) ...@@ -49,17 +79,8 @@ constexpr auto tensor_step(T x, Axis)
else else
{ {
constexpr auto s = decltype(x.get_shape()){}; constexpr auto s = decltype(x.get_shape()){};
MIGRAPHX_ASSERT(s.strides[Axis{}] == 0); MIGRAPHX_ASSERT(s.strides[axis] == 0);
return sequence(x.get_shape().lens.size(), [&](auto... is) { return make_tensor_view(x.data(), shape_step<N>(s, axis));
auto lens = transform(s.lens, index_ints<is...>{}, [&](auto i, auto j) {
constexpr auto axis = Axis{};
if(j == axis)
return i / N;
else
return i;
});
return make_tensor_view(x.data(), make_shape(lens, s.strides));
});
} }
} }
...@@ -69,45 +90,71 @@ __device__ __host__ auto as_vec(IntegralConstant ic, T&& x) ...@@ -69,45 +90,71 @@ __device__ __host__ auto as_vec(IntegralConstant ic, T&& x)
return as_vec<ic>(x); return as_vec<ic>(x);
} }
template <class... Shapes> template <class Shape>
constexpr index_int find_vector_axis(Shapes... ss) constexpr index_int find_vector_axis_c(Shape s)
{ {
// Find the fastest axis that is not broadcasted
index_int axis = 0; index_int axis = 0;
bool b = false; for(index_int i = 1; i < s.lens.size(); i++)
{
if(s.strides[i] == 0)
continue;
if(s.strides[axis] == 0 or
pack_compare(less{}, pack(s.strides[i], s.lens[i]), pack(s.strides[axis], s.lens[axis])))
axis = i;
}
return axis;
}
template <class... Shapes>
constexpr index_int find_vector_axis_c(Shapes... ss)
{
const bool all_broadcasted = (ss.broadcasted() and ...);
index_int axis = 0;
bool b = false;
by([&](auto s) { by([&](auto s) {
if(b) if(b)
return; return;
auto it = find(s.strides.begin(), s.strides.end(), 1); // Skip broadcasted shapes if there are shapes not broadcasted
if(it == s.strides.end()) if(not all_broadcasted and s.broadcasted())
return; return;
axis = it - s.strides.begin(); axis = find_vector_axis_c(s);
b = true; if(s.strides[axis] == 1)
b = true;
})(ss...); })(ss...);
if(not b)
return -1;
return axis; return axis;
} }
template <class... Shapes>
constexpr auto find_vector_axis(Shapes...)
{
return _c<find_vector_axis_c(Shapes{}...)>;
}
template <index_int N, class Axis, class... Shapes> template <index_int N, class Axis, class... Shapes>
constexpr auto is_vectorizable(Axis axis, Shapes... ss) constexpr auto is_vectorizable_c(Axis axis, Shapes... ss)
{ {
return (((ss.lens[axis] % N) == 0 and ss.strides[axis] == 1) and ...); return ((axis < ss.lens.size() and ss.lens[axis] % N == 0 and
// Only vectorize broadcasted types with stride 0, since this causes issues in the
// preloader
((not ss.broadcasted() and ss.strides[axis] == 1) or ss.strides[axis] == 0)) and
...);
} }
template <index_int N, class Shape> template <index_int N, class Axis, class... Shapes>
constexpr bool is_vectorizable(Shape s) constexpr auto is_vectorizable(Axis, Shapes...)
{ {
auto it = find(s.strides.begin(), s.strides.end(), 1); return _c<is_vectorizable_c<N>(Axis::to(), Shapes{}...)>;
if(it == s.strides.end())
return false;
auto axis = it - s.strides.begin();
return (s.lens[axis] % N) == 0 and s.strides[axis] == 1;
} }
template <class P> template <class P>
constexpr auto find_vectorize_size(P pred) constexpr auto find_vectorize_size(P pred)
{ {
if constexpr(pred(_c<4>)) if constexpr(decltype(pred(_c<4>)){})
return _c<4>; return _c<4>;
else if constexpr(pred(_c<2>)) else if constexpr(decltype(pred(_c<2>)){})
return _c<2>; return _c<2>;
else else
return _c<0>; return _c<0>;
...@@ -116,11 +163,12 @@ constexpr auto find_vectorize_size(P pred) ...@@ -116,11 +163,12 @@ constexpr auto find_vectorize_size(P pred)
template <class T> template <class T>
__host__ __device__ auto vectorize(T x) __host__ __device__ auto vectorize(T x)
{ {
if constexpr(vec_size<T>() == 0) if constexpr(tensor_vec_size<T>() == 0)
{ {
constexpr auto axis = find_vector_axis(x.get_shape());
constexpr auto n = constexpr auto n =
find_vectorize_size([&](auto i) { return _c<is_vectorizable<i>(x.get_shape())>; }); find_vectorize_size([&](auto i) { return is_vectorizable<i>(axis, x.get_shape()); });
return as_vec<n>(x); return as_vec<n>(x, axis);
} }
else else
{ {
...@@ -128,34 +176,46 @@ __host__ __device__ auto vectorize(T x) ...@@ -128,34 +176,46 @@ __host__ __device__ auto vectorize(T x)
} }
} }
template <class F, class... Ts>
inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
{
// TODO: Just check there a single axis of 1
constexpr bool packed_or_broadcasted =
((xs.get_shape().packed() or xs.get_shape().broadcasted()) and ...);
if constexpr(packed_or_broadcasted)
{
constexpr auto axis = decltype(find_vector_axis(xs.get_shape()...)){};
constexpr auto n = find_vectorize_size(
[&](auto i) { return is_vectorizable<i>(axis, xs.get_shape()...); });
by(
[&](auto x) {
constexpr auto s = decltype(x.get_shape()){};
if constexpr(axis < s.strides.size())
{
MIGRAPHX_ASSERT(s.strides[axis] == 0 or s.strides[axis] == 1);
MIGRAPHX_ASSERT(s.lens[axis] > 0);
MIGRAPHX_ASSERT(n == 0 or s.lens[axis] % n == 0);
if constexpr(s.strides[axis] == 0)
return tensor_step<n>(x, axis);
else
return as_vec<n>(x, axis);
}
else
{
return x;
}
},
f)(xs...);
}
else
{
f(xs...);
}
}
inline __device__ __host__ auto auto_vectorize() inline __device__ __host__ auto auto_vectorize()
{ {
return [](auto... xs) { return [](auto... xs) { return [=](auto f) { auto_vectorize_impl(f, xs...); }; };
return [=](auto f) {
// TODO: Just check there a single axis of 1
constexpr bool packed_or_broadcasted =
((xs.get_shape().packed() or xs.get_shape().broadcasted()) and ...);
if constexpr(packed_or_broadcasted)
{
constexpr auto axis = find_vector_axis(xs.get_shape()...);
constexpr auto n = find_vectorize_size(
[&](auto i) { return _c<is_vectorizable<i>(axis, xs.get_shape()...)>; });
by(
[&](auto x) {
constexpr auto s = x.get_shape();
if constexpr(s.strides[axis] == 0)
return tensor_step<n>(x, axis);
else
return as_vec<n>(x);
},
f)(xs...);
}
else
{
f(xs...);
}
};
};
} }
} // namespace migraphx } // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment