Commit d5a32cd2 authored by Shucai Xiao

merge changes from device refactor.

parents 89dfc4dd 8be483c5
#ifndef MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/requires.hpp>
#include <type_traits>
#include <array>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
template <class R, class...>
struct array_type
{
using type = R;
};
template <class... Ts>
struct array_type<void, Ts...> : std::common_type<Ts...>
{
};
template <class R, class... Ts>
using array_type_t = typename array_type<R, Ts...>::type;
template <class T, std::size_t N, std::size_t... I>
constexpr std::array<std::remove_cv_t<T>, N> to_array_impl(T (&a)[N], seq<I...>)
{
return {{a[I]...}};
}
} // namespace detail
template <class Result = void, class... Ts, MIGRAPHX_REQUIRES((sizeof...(Ts) > 0))>
constexpr std::array<detail::array_type_t<Result, Ts...>, sizeof...(Ts)> make_array(Ts&&... xs)
{
return {static_cast<detail::array_type_t<Result, Ts...>>(std::forward<Ts>(xs))...};
}
constexpr std::array<int, 0> make_array() { return {}; }
template <class T, std::size_t N>
constexpr auto to_array(T (&a)[N])
{
return detail::to_array_impl(a, detail::gens<N>{});
}
namespace detail {
template <std::size_t Offset = 0, class Array, std::size_t... I>
constexpr auto rearray_impl(Array a, seq<I...>)
{
return make_array(a[I + Offset]...);
}
} // namespace detail
template <class T, std::size_t N>
constexpr auto pop_front(std::array<T, N> a)
{
    // Offset 1 drops the first element, keeping elements 1..N-1
    return detail::rearray_impl<1>(a, detail::gens<N - 1>{});
}
template <class T, std::size_t N>
constexpr auto pop_back(std::array<T, N> a)
{
    // Offset 0 (the default) keeps elements 0..N-2, dropping the last
    return detail::rearray_impl(a, detail::gens<N - 1>{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
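A quick illustration (not part of the commit) of how these helpers behave; make_array deduces a common element type unless one is given explicitly:

// Illustrative sketch only; assumes just the header above.
#include <migraphx/array.hpp>

constexpr auto a = migraphx::make_array(1, 2, 3);        // std::array<int, 3>
constexpr auto b = migraphx::make_array<long>(1, 2u, 3); // std::array<long, 3>
constexpr auto c = migraphx::pop_front(a);               // {2, 3}
constexpr auto d = migraphx::pop_back(a);                // {1, 2}
static_assert(c[0] == 2 and d[1] == 2, "each helper drops one element");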
@@ -15,6 +15,12 @@ struct swallow
     }
 };
 
+template <class T>
+auto tuple_size(const T&)
+{
+    return typename std::tuple_size<T>::type{};
+}
+
 namespace detail {
 template <class R, class F>
@@ -83,6 +89,12 @@ constexpr auto sequence_c(F&& f)
     return detail::sequence_c_impl(f, detail::gens<N>{});
 }
 
+template <class IntegerConstant, class F>
+constexpr auto sequence(IntegerConstant ic, F&& f)
+{
+    return sequence_c<ic>(f);
+}
+
 template <class F, class... Ts>
 constexpr void each_args(F f, Ts&&... xs)
 {
@@ -95,9 +107,9 @@ constexpr void each_args(F)
 }
 
 template <class F, class T>
-auto unpack(F f, T& x)
+auto unpack(F f, T&& x)
 {
-    return sequence_c<std::tuple_size<T>{}>([&](auto... is) { f(std::get<is>(x)...); });
+    return sequence(tuple_size(x), [&](auto... is) { f(std::get<is>(static_cast<T&&>(x))...); });
 }
 
 /// Implements a fix-point combinator
@@ -149,6 +161,35 @@ auto index_of(T& x)
     return [&](auto&& y) { return x[y]; };
 }
 
+template <class T, class... Ts>
+decltype(auto) front_args(T&& x, Ts&&...)
+{
+    return static_cast<T&&>(x);
+}
+
+template <class... Ts>
+decltype(auto) back_args(Ts&&... xs)
+{
+    return std::get<sizeof...(Ts) - 1>(std::tuple<Ts&&...>(static_cast<Ts&&>(xs)...));
+}
+
+template <class T, class... Ts>
+auto pop_front_args(T&&, Ts&&... xs)
+{
+    return [&](auto f) { f(static_cast<Ts&&>(xs)...); };
+}
+
+template <class... Ts>
+auto pop_back_args(Ts&&... xs)
+{
+    return [&](auto f) {
+        using tuple_type = std::tuple<Ts&&...>;
+        auto t           = tuple_type(static_cast<Ts&&>(xs)...);
+        return sequence_c<sizeof...(Ts) - 1>(
+            [&](auto... is) { return f(std::get<is>(static_cast<tuple_type&&>(t))...); });
+    };
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
......
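The new pack helpers above split a parameter pack around its last element, which is what lets the variadic nary() later in this commit treat the trailing argument as the broadcast input. An illustrative sketch (not part of the commit):

#include <migraphx/functional.hpp>
#include <iostream>

int main()
{
    using namespace migraphx;
    std::cout << front_args(1, 2, 3) << "\n"; // prints 1
    std::cout << back_args(1, 2, 3) << "\n";  // prints 3
    // pop_back_args forwards all but the last argument to a continuation
    pop_back_args(1, 2, 3)([](auto... xs) { std::cout << sizeof...(xs) << "\n"; }); // prints 2
    return 0;
}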
@@ -33,6 +33,10 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
     return std::find(c.begin(), c.end(), x);
 }
 
+struct empty
+{
+};
+
 } // namespace detail
 
 template <class C, class T>
@@ -71,6 +75,12 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::all_of(c.begin(), c.end(), p);
 }
 
+template <class Predicate>
+bool all_of(detail::empty, const Predicate&)
+{
+    return true;
+}
+
 template <class C, class Predicate>
 bool any_of(const C& c, const Predicate& p)
 {
@@ -83,6 +93,12 @@ bool any_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::any_of(c.begin(), c.end(), p);
 }
 
+template <class Predicate>
+bool any_of(detail::empty, const Predicate&)
+{
+    return false;
+}
+
 template <class C, class Predicate>
 bool none_of(const C& c, const Predicate& p)
 {
@@ -95,6 +111,12 @@ bool none_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::none_of(c.begin(), c.end(), p);
 }
 
+template <class Predicate>
+bool none_of(detail::empty, const Predicate&)
+{
+    return true;
+}
+
 template <class Range, class Iterator>
 void copy(Range&& r, Iterator it)
 {
......
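The detail::empty overloads exist because a bare {} cannot deduce an initializer_list element type; it binds to detail::empty instead, and the algorithms return their vacuous-truth values. A sketch (not in the commit):

#include <migraphx/ranges.hpp>

bool example()
{
    auto positive = [](int x) { return x > 0; };
    bool a = migraphx::all_of({1, 2, 3}, positive); // true
    bool b = migraphx::all_of({}, positive);        // true, vacuously
    bool c = migraphx::any_of({}, positive);        // false
    return a and b and not c;
}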
@@ -212,6 +212,25 @@ auto visit_all(T&& x, Ts&&... xs)
     };
 }
 
+template <class T>
+auto visit_all(const std::vector<T>& x)
+{
+    auto&& s = x.front().get_shape();
+    if(!std::all_of(
+           x.begin(), x.end(), [&](const T& y) { return y.get_shape().type() == s.type(); }))
+        MIGRAPHX_THROW("Types must be the same");
+    return [&](auto v) {
+        s.visit_type([&](auto as) {
+            using type = typename decltype(as)::type;
+            std::vector<tensor_view<type>> result;
+            std::transform(x.begin(), x.end(), std::back_inserter(result), [&](const auto& y) {
+                return make_view(y.get_shape(), as.from(y.data()));
+            });
+            v(result);
+        });
+    };
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
......
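Illustrative use of the new vector overload (not in the commit); the visitor receives a std::vector<tensor_view<T>> with a single deduced element type:

// Sketch only: sums every element of every argument in the vector.
double sum_all(const std::vector<migraphx::argument>& args)
{
    double total = 0;
    migraphx::visit_all(args)([&](auto views) {
        for(const auto& v : views)
            for(auto x : v)
                total += x;
    });
    return total;
}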
@@ -10,22 +10,20 @@ namespace gpu {
 namespace device {
 
 argument concat(hipStream_t stream,
-                const migraphx::shape& output_shape,
+                const migraphx::shape&,
                 std::vector<migraphx::argument> args,
                 std::vector<std::size_t> offsets)
 {
-    for(std::size_t l = 0; l < args.size() - 1; l++)
+    auto ninputs = args.size() - 1;
+    for(std::size_t j = 0; j < ninputs; j++)
     {
-        auto argl             = args[l];
-        std::size_t nelements = argl.get_shape().elements();
-        visit_all(args.back(), argl)([&](auto output, auto input) {
-            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
-                auto* outptr      = output.data() + offsets[l];
-                const auto* inptr = input.data();
-                hip_tensor_descriptor<ndim> desc_input(input.get_shape());
-                hip_tensor_descriptor<ndim> desc_output(output.get_shape());
-                gs_launch(stream, nelements)(
-                    [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
-            });
+        auto&& arg            = args[j];
+        std::size_t nelements = arg.get_shape().elements();
+        auto offset           = offsets[j];
+        hip_visit_all(args.back(), arg)([&](auto output, auto input) {
+            gs_launch(stream, nelements)([=](auto i) {
+                auto idx = output.get_shape().index(input.get_shape().multi(i));
+                output.data()[idx + offset] = input.data()[i];
+            });
         });
     }
......
@@ -11,35 +11,30 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-argument gather(hipStream_t stream,
-                const migraphx::shape& output_shape,
-                std::vector<migraphx::argument> args,
-                int axis)
+argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis)
 {
-    auto axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis;
-    visit_all(args.back(), args[0])([&](auto output, auto input) {
-        std::size_t nelements = output_shape.elements();
-        args[1].visit([&](auto indices) {
-            const auto* indices_ptr = device_cast(indices.data());
-            auto* out_ptr           = device_cast(output.data());
-            const auto* in_ptr      = device_cast(input.data());
-            auto& input_shape       = args[0].get_shape();
-            auto lens               = input_shape.lens();
-            lens[axis_index]        = args[1].get_shape().elements();
-            migraphx::shape out_comp_shape{output_shape.type(), lens};
-            visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
-                hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
-                hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
-                gs_launch(stream, nelements)([=](auto ii) {
-                    auto in_idx        = desc_output.multi(ii);
-                    in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
-                    out_ptr[ii]        = in_ptr[desc_input.linear(in_idx)];
-                });
+    auto axis_index       = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
+    auto& input_shape     = arg1.get_shape();
+    auto lens             = input_shape.lens();
+    lens[axis_index]      = arg2.get_shape().elements();
+    shape out_comp_shape{result.get_shape().type(), lens};
+    std::size_t nelements = result.get_shape().elements();
+
+    visit_all(result, arg1)([&](auto output, auto input_v) {
+        hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) {
+            arg2.visit([&](auto indices) {
+                const auto* indices_ptr = device_cast(indices.data());
+                auto* output_ptr        = device_cast(output.data());
+                gs_launch(stream, nelements)([=](auto i) {
+                    auto idx        = out_comp.multi(i);
+                    idx[axis_index] = indices_ptr[idx[axis_index]];
+                    output_ptr[i]   = input[idx];
+                });
             });
         });
     });
-    return args.back();
+    return result;
 }
 
 } // namespace device
......
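For reference, what the rewritten kernel computes, reduced to a host-side sketch for axis 0 (not in the commit):

#include <array>
#include <vector>

// Each output row k is input row indices[k]; on the GPU the same lookup is done
// per element by rewriting the axis coordinate of the output multi-index.
std::vector<std::array<float, 2>> gather_rows(const std::vector<std::array<float, 2>>& input,
                                              const std::vector<int>& indices)
{
    std::vector<std::array<float, 2>> out;
    out.reserve(indices.size());
    for(int k : indices)
        out.push_back(input[k]);
    return out;
}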
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARRAY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARRAY_HPP
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T, std::size_t N>
struct hip_array
{
T d[N];
MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR std::integral_constant<std::size_t, N> size() const { return {}; }
MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d + size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d + size(); }
MIGRAPHX_DEVICE_CONSTEXPR T dot(const hip_array& x) const
{
T result = 0;
for(std::size_t i = 0; i < N; i++)
result += x[i] * d[i];
return result;
}
MIGRAPHX_DEVICE_CONSTEXPR T product() const
{
T result = 1;
for(std::size_t i = 0; i < N; i++)
result *= d[i];
return result;
}
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator*(const hip_array& x, const hip_array& y)
{
hip_array result;
for(std::size_t i = 0; i < N; i++)
result[i] = x[i] * y[i];
return result;
}
};
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
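The dot and product members are the index arithmetic the device shape below builds on; a sketch (not in the commit):

// For row-major lens {2, 3} the strides are {3, 1}; multi-index {1, 2} maps to
// linear offset 1*3 + 2*1 = 5, and product() gives the element count.
using migraphx::gpu::device::hip_array;

constexpr hip_array<std::size_t, 2> strides{{3, 1}};
constexpr hip_array<std::size_t, 2> index{{1, 2}};
static_assert(index.dot(strides) == 5, "linear offset");
static_assert(hip_array<std::size_t, 2>{{2, 3}}.product() == 6, "element count");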
 #ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP
 #define MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP
 
-#include <migraphx/gpu/device/tensor.hpp>
 #include <migraphx/gpu/device/launch.hpp>
-#include <migraphx/gpu/device/types.hpp>
+#include <migraphx/gpu/device/visit.hpp>
 #include <migraphx/functional.hpp>
 #include <migraphx/ranges.hpp>
+#include <migraphx/array.hpp>
 #include <migraphx/config.hpp>
 
 namespace migraphx {
@@ -13,57 +13,30 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-template <class T>
-using vec4 = T __attribute__((ext_vector_type(4)));
-
-template <class T>
-__device__ __host__ vec4<T>* as_vec4(T* x)
-{
-    return reinterpret_cast<vec4<T>*>(x);
-}
-
-template <class T>
-__device__ __host__ T* as_pointer(vec4<T>* x)
-{
-    return reinterpret_cast<T*>(x);
-}
-
 template <class... Ts>
-auto pack_vec4(Ts... xs)
+auto pack(Ts... xs) __device__
 {
-    return [=](auto f, std::size_t n) { return f(as_vec4(xs)[n]...); };
+    return [=](auto f) { return f(xs...); };
 }
 
 template <class F, class... Arguments>
 auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
-    const auto& output_shape = result.get_shape();
-    visit_all(result, args...)([&](auto output, auto... inputs) {
-        visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
-            auto data = pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()},
                                            device_cast(inputs.data()))...);
-            hip_tensor_descriptor<ndim> out_desc(output_shape);
-            auto* outp = device_cast(output.data());
-            gs_launch(stream, output_shape.elements())([=](auto i) {
-                data([&](auto&&... ps) {
-                    auto outidx = out_desc.multi(i);
-                    outp[i]     = f(ps.second[ps.first.linear(outidx)]...);
-                });
-            });
-        });
-    });
+    std::size_t nelements = result.get_shape().elements();
+    hip_visit_all(result, args...)([&](auto output, auto... inputs) {
+        gs_launch(stream, nelements)([=](auto i) {
+            auto idx  = output.get_shape().multi(i);
+            output[i] = f(inputs[idx]...);
+        });
+    });
 }
 
-template <class F>
-void trinary_broadcast_vec_impl(hipStream_t stream,
-                                F f,
-                                const argument& result,
-                                const argument& arg1,
-                                const argument& arg2,
-                                const argument& arg3)
+template <class F, class... Arguments>
+void nary_broadcast_vec_impl(
+    hipStream_t stream, F f, argument result, argument barg, Arguments... args)
 {
     const auto& output_shape = result.get_shape();
-    const auto& b_shape      = arg3.get_shape();
+    const auto& b_shape      = barg.get_shape();
     auto bdim =
         std::distance(b_shape.strides().begin(),
                       std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
@@ -73,156 +46,45 @@ void trinary_broadcast_vec_impl(hipStream_t stream,
     auto bdim_stride      = output_shape.strides()[bdim];
     auto bdim_next_stride = bdim_stride * bdim_len;
 
-    visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
-        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = as_vec4(device_cast(input1.data()));
-        auto* yp   = as_vec4(device_cast(input2.data()));
-        auto* zp   = as_vec4(device_cast(input3.data()));
-        auto* outp = as_vec4(device_cast(output.data()));
-
-        const std::size_t vec_size     = 4;
-        const std::size_t nlocal       = 1024;
-        const std::size_t nglobal      = 256 * nlocal;
-        const std::size_t n            = output.size() / vec_size;
-        const std::size_t bdim_vec_len = bdim_len / vec_size;
-
-        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
-            MIGRAPHX_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
-            // Load bias into LDS
-            for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
-            {
-                buffer[i] = zp[i];
-            }
-            __syncthreads();
-            auto* bp = as_pointer(buffer);
-            // Process the data
-            for(size_t i = idx.global; i < n; i += nglobal)
-            {
-                auto bidx      = ((i * vec_size) % bdim_next_stride) / bdim_stride;
-                auto b         = bp[bidx];
-                vec4<type> x   = xp[i];
-                vec4<type> y   = yp[i];
-                vec4<type> out = outp[i];
-                for(std::size_t j = 0; j < vec_size; j++)
-                {
-                    out[j] = f(x[j], y[j], b);
-                }
-                outp[i] = out;
-            }
-        });
-    });
-}
-
-template <class F>
-void trinary_broadcast_impl(hipStream_t stream,
-                            F f,
-                            const argument& result,
-                            const argument& arg1,
-                            const argument& arg2,
-                            const argument& arg3)
-{
-    const auto& output_shape = result.get_shape();
-    const auto& b_shape      = arg3.get_shape();
-    auto bdim =
-        std::distance(b_shape.strides().begin(),
-                      std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
-                          return x != 0;
-                      }));
-    auto bdim_len         = output_shape.lens()[bdim];
-    auto bdim_stride      = output_shape.strides()[bdim];
-    auto bdim_next_stride = bdim_stride * bdim_len;
-
-    visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
-        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = device_cast(input1.data());
-        auto* yp   = device_cast(input2.data());
-        auto* zp   = device_cast(input3.data());
-        auto* outp = device_cast(output.data());
-
-        const std::size_t nlocal  = 1024;
-        const std::size_t nglobal = 256 * nlocal;
-        const std::size_t n       = output.size();
-
-        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
-            MIGRAPHX_DEVICE_SHARED type buffer[2048];
-            // Load bias into LDS
-            for(size_t i = idx.local; i < bdim_len; i += nlocal)
-            {
-                buffer[i] = zp[i];
-            }
-            __syncthreads();
-            // Process the data
-            for(size_t i = idx.global; i < n; i += nglobal)
-            {
-                auto bidx = (i % bdim_next_stride) / bdim_stride;
-                auto b    = buffer[bidx];
-                type x    = xp[i];
-                type y    = yp[i];
-                outp[i]   = f(x, y, b);
-            }
-        });
-    });
-}
-
-template <class F>
-void binary_broadcast_vec_impl(
-    hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2)
-{
-    const auto& output_shape = result.get_shape();
-    const auto& b_shape      = arg2.get_shape();
-    auto bdim =
-        std::distance(b_shape.strides().begin(),
-                      std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
-                          return x != 0;
-                      }));
-    auto bdim_len         = output_shape.lens()[bdim];
-    auto bdim_stride      = output_shape.strides()[bdim];
-    auto bdim_next_stride = bdim_stride * bdim_len;
-
-    visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
-        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = as_vec4(device_cast(input1.data()));
-        auto* yp   = as_vec4(device_cast(input2.data()));
-        auto* outp = as_vec4(device_cast(output.data()));
-
-        const std::size_t vec_size     = 4;
-        const std::size_t nlocal       = 1024;
-        const std::size_t nglobal      = 256 * nlocal;
-        const std::size_t n            = output.size() / vec_size;
-        const std::size_t bdim_vec_len = bdim_len / vec_size;
-
-        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
-            MIGRAPHX_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
-            // Load bias into LDS
-            for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
-            {
-                buffer[i] = yp[i];
-            }
-            __syncthreads();
-            auto* bp = as_pointer(buffer);
-            // Process the data
-            for(size_t i = idx.global; i < n; i += nglobal)
-            {
-                auto bidx      = ((i * vec_size) % bdim_next_stride) / bdim_stride;
-                auto b         = bp[bidx];
-                vec4<type> x   = xp[i];
-                vec4<type> out = outp[i];
-                for(std::size_t j = 0; j < vec_size; j++)
-                {
-                    out[j] = f(x[j], b);
-                }
-                outp[i] = out;
-            }
-        });
-    });
-}
+    const std::size_t vec_size     = 4;
+    const std::size_t nlocal       = 1024;
+    const std::size_t nglobal      = 256 * nlocal;
+    const std::size_t bdim_vec_len = bdim_len / vec_size;
+    hip_vec_visit_all<vec_size>(result, barg, args...)(
+        [&](auto output, auto binput, auto... inputs) {
+            using type                  = typename decltype(output)::value_type;
+            const std::size_t nelements = output.size() / vec_size;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i] = binput.data()[i];
+                }
+                __syncthreads();
+                auto* bp = as_pointer(buffer);
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
+                    auto b    = bp[bidx];
+                    auto out  = output.data()[i];
+                    for(std::size_t j = 0; j < vec_size; j++)
+                    {
+                        out[j] = f(inputs.data()[i][j]..., b);
+                    }
+                    output.data()[i] = out;
+                }
+            });
+        });
+}
 
-template <class F>
-void binary_broadcast_impl(
-    hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2)
+template <class F, class... Arguments>
+void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg, Arguments... args)
 {
     const auto& output_shape = result.get_shape();
-    const auto& b_shape      = arg2.get_shape();
+    const auto& b_shape      = barg.get_shape();
     auto bdim =
         std::distance(b_shape.strides().begin(),
                       std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
@@ -232,31 +94,25 @@ void binary_broadcast_impl(
     auto bdim_stride      = output_shape.strides()[bdim];
     auto bdim_next_stride = bdim_stride * bdim_len;
 
-    visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
-        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = device_cast(input1.data());
-        auto* yp   = device_cast(input2.data());
-        auto* outp = device_cast(output.data());
-
-        const std::size_t nlocal  = 1024;
-        const std::size_t nglobal = 256 * nlocal;
-        const std::size_t n       = output.size();
-
+    const std::size_t nlocal  = 1024;
+    const std::size_t nglobal = 256 * nlocal;
+    std::size_t nelements     = result.get_shape().elements();
+    hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
+        using type = typename decltype(output)::value_type;
         launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
             MIGRAPHX_DEVICE_SHARED type buffer[2048];
             // Load bias into LDS
             for(size_t i = idx.local; i < bdim_len; i += nlocal)
             {
-                buffer[i] = yp[i];
+                buffer[i] = binput.data()[i];
             }
             __syncthreads();
             // Process the data
-            for(size_t i = idx.global; i < n; i += nglobal)
+            for(size_t i = idx.global; i < nelements; i += nglobal)
             {
                 auto bidx = (i % bdim_next_stride) / bdim_stride;
                 auto b    = buffer[bidx];
-                type x    = xp[i];
-                outp[i]   = f(x, b);
+                output.data()[i] = f(inputs.data()[i]..., b);
             }
         });
     });
@@ -265,15 +121,14 @@ void binary_broadcast_impl(
 template <class F, class... Arguments>
 void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
-    // assert(x.get_shape().elements() == y.get_shape().elements());
     const auto& output_shape = result.get_shape();
     visit_all(result, args...)([&](auto output, auto... inputs) {
         using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
         const std::size_t vec_size = 4;
-        auto data  = pack_vec4(device_cast(inputs.data())...);
-        auto* outp = as_vec4(device_cast(output.data()));
+        auto data  = pack_vec<4>(device_cast(inputs.data())...);
+        auto* outp = as_vec<4>(device_cast(output.data()));
         gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) {
-            vec4<type> out = outp[i];
+            vec<type, 4> out = outp[i];
             data(
                 [&](auto... xs) {
                     for(std::size_t j = 0; j < vec_size; j++)
@@ -290,13 +145,9 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
 template <class F, class... Arguments>
 void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
-    // assert(x.get_shape().elements() == y.get_shape().elements());
-    const auto& output_shape = result.get_shape();
-    visit_all(result, args...)([&](auto output, auto... inputs) {
-        auto data  = pack(device_cast(inputs.data())...);
-        auto* outp = device_cast(output.data());
-        gs_launch(stream, output_shape.elements())(
-            [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
+    std::size_t nelements = result.get_shape().elements();
+    hip_pointer_visit_all(result, args...)([&](auto output, auto... inputs) {
+        gs_launch(stream, nelements)([=](auto i) { output[i] = f(inputs[i]...); });
     });
 }
 
@@ -313,12 +164,6 @@ void nary_impl(hipStream_t stream, F f, argument result, Arguments... args)
         nary_nonstandard_impl(stream, f, result, args...);
 }
 
-template <class F>
-void nary_impl(hipStream_t stream, F f, argument result)
-{
-    nary_standard_impl(stream, f, result);
-}
-
 template <class... Arguments>
 auto nary_nonstandard(hipStream_t stream, argument result, Arguments... args)
 {
@@ -332,71 +177,50 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
 }
 
 template <class... Arguments>
-auto nary(hipStream_t stream, argument result, Arguments... args)
+auto nary(hipStream_t stream, argument result)
 {
-    return [=](auto f) { nary_impl(stream, f, result, args...); };
+    return [=](auto f) { nary_standard_impl(stream, f, result); };
 }
 
-inline auto
-nary(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
+template <class... Arguments>
+auto nary(hipStream_t stream, argument result, Arguments... args)
 {
     return [=](auto f) {
-        // TODO: Check result and arg1 shape is the same
-        if(arg1.get_shape().standard() and arg2.get_shape().broadcasted() and
-           not arg2.get_shape().scalar())
-        {
-            auto not_zero       = [](auto x) { return x != 0; };
-            const auto& strides = arg2.get_shape().strides();
-            auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
-            auto b_idx          = std::distance(strides.begin(), b_it);
-            auto b_len          = result.get_shape().lens()[b_idx];
-            auto b_stride       = result.get_shape().strides()[b_idx];
-            assert(arg2.get_shape().lens()[b_idx] == b_len);
-            if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
-            {
-                const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
-                                            (arg1.get_shape().elements() % 4 == 0);
-                if(divisible_by_4)
-                    binary_broadcast_vec_impl(stream, f, result, arg1, arg2);
-                else
-                    binary_broadcast_impl(stream, f, result, arg1, arg2);
-                return;
-            }
-        }
-        nary_impl(stream, f, result, arg1, arg2);
-    };
-}
-
-inline auto nary(hipStream_t stream,
-                 const argument& result,
-                 const argument& arg1,
-                 const argument& arg2,
-                 const argument& arg3)
-{
-    return [=](auto f) {
-        // TODO: Check result and arg1 shape is the same
-        if(arg1.get_shape().standard() and arg2.get_shape().standard() and
-           arg3.get_shape().broadcasted())
-        {
-            auto not_zero       = [](auto x) { return x != 0; };
-            const auto& strides = arg3.get_shape().strides();
-            auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
-            auto b_idx          = std::distance(strides.begin(), b_it);
-            auto b_len          = result.get_shape().lens()[b_idx];
-            auto b_stride       = result.get_shape().strides()[b_idx];
-            assert(arg3.get_shape().lens()[b_idx] == b_len);
-            if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
-            {
-                const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
-                                            (arg1.get_shape().elements() % 4 == 0);
-                if(divisible_by_4)
-                    trinary_broadcast_vec_impl(stream, f, result, arg1, arg2, arg3);
-                else
-                    trinary_broadcast_impl(stream, f, result, arg1, arg2, arg3);
-                return;
-            }
-        }
-        nary_impl(stream, f, result, arg1, arg2, arg3);
+        auto barg     = back_args(args...);
+        bool fallback = pop_back_args(args...)([&](auto&&... args2) {
+            auto bshape = barg.get_shape();
+            const bool standard =
+                all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
+            const bool same_shapes = all_of(
+                {args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
+            // TODO: Check result and args shape is the same
+            if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
+            {
+                auto not_zero       = [](auto x) { return x != 0; };
+                const auto& strides = bshape.strides();
+                auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
+                auto b_idx          = std::distance(strides.begin(), b_it);
+                auto b_len          = result.get_shape().lens()[b_idx];
+                auto b_stride       = result.get_shape().strides()[b_idx];
+                assert(bshape.lens()[b_idx] == b_len);
+                if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
+                {
+                    const bool divisible_by_4 =
+                        (b_len % 4 == 0) and (b_stride % 4 == 0) and
+                        (front_args(args...).get_shape().elements() % 4 == 0);
+                    if(divisible_by_4)
+                        nary_broadcast_vec_impl(stream, f, result, barg, args2...);
+                    else
+                        nary_broadcast_impl(stream, f, result, barg, args2...);
+                    return false;
+                }
+            }
+            return true;
+        });
+        if(fallback)
+            nary_impl(stream, f, result, args...);
     };
 }
......
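Illustrative of how the refactored entry point is consumed by pointwise ops (a sketch, not in the commit; add_bias is hypothetical): the broadcast fast path is selected automatically when the trailing argument is a packed broadcast such as a bias and the leading arguments are standard.

// Inside namespace migraphx::gpu::device; a hypothetical op built on nary().
void add_bias(hipStream_t stream, argument result, argument x, argument bias)
{
    nary(stream, result, x, bias)([](auto a, auto b) __device__ { return a + b; });
}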
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SHAPE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SHAPE_HPP
#include <migraphx/gpu/device/array.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <std::size_t N>
struct hip_shape
{
using hip_index = hip_array<std::size_t, N>;
hip_array<std::size_t, N> lens = {};
hip_array<std::size_t, N> strides = {};
bool standard = false;
__device__ __host__ hip_shape() = default;
hip_shape(const shape& s) : standard(s.standard())
{
assert(s.lens().size() == N);
assert(s.strides().size() == N);
std::copy(s.lens().begin(), s.lens().end(), lens.begin());
std::copy(s.strides().begin(), s.strides().end(), strides.begin());
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t elements() const { return lens.product(); }
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const { return x.dot(strides); }
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::initializer_list<std::size_t> x) const
{
std::size_t idx = 0;
for(std::size_t i = 0; i < x.size(); i++)
idx += *(x.begin() + i) * strides[i];
return idx;
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::size_t i) const
{
if(this->standard)
return i;
else
{
const std::size_t rank = this->lens.size();
std::size_t s = 1;
std::size_t result = 0;
for(std::size_t j = 0; j < this->lens.size(); j++)
{
const std::size_t k = rank - j - 1;
const std::size_t stride = this->strides[k];
const std::size_t len = this->lens[k];
const std::size_t slen = s * len;
const std::size_t idx = (i % slen) / s;
result += stride * idx;
s = slen;
}
return result;
}
}
MIGRAPHX_DEVICE_CONSTEXPR hip_index multi(std::size_t idx) const
{
hip_index result;
std::size_t tidx = idx;
for(std::size_t is = 0; is < result.size(); is++)
{
result[is] = tidx / strides[is];
tidx = tidx % strides[is];
}
return result;
}
};
template <std::size_t N>
hip_shape<N> make_hip_shape(const shape& x)
{
return x;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
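A worked example of the round trip between linear and multi-dimensional indices (sketch, not in the commit):

#include <cassert>

// For lens {2, 3} a standard shape has strides {3, 1}; element 4 is {1, 1}:
// multi(4): 4 / 3 = 1 for dim 0, then 4 % 3 = 1 and 1 / 1 = 1 for dim 1.
// index({1, 1}) = 1*3 + 1*1 = 4.
void shape_round_trip()
{
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    auto hs  = migraphx::gpu::device::make_hip_shape<2>(s);
    auto idx = hs.multi(4);
    assert(idx[0] == 1 and idx[1] == 1);
    assert(hs.index(idx) == 4); // round-trips for standard shapes
}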
 #ifndef MIGRAPHX_GUARD_RTGLIB_DEAVICE_TENSOR_HPP
 #define MIGRAPHX_GUARD_RTGLIB_DEAVICE_TENSOR_HPP
 
-#include <hip/hip_runtime.h>
-#include <migraphx/functional.hpp>
-#include <migraphx/config.hpp>
+#include <migraphx/gpu/device/visit.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-template <class F>
-void visit_tensor_size(std::size_t n, F f)
-{
-    switch(n)
-    {
-    case 1:
-    {
-        f(std::integral_constant<std::size_t, 1>{});
-        break;
-    }
-    case 2:
-    {
-        f(std::integral_constant<std::size_t, 2>{});
-        break;
-    }
-    case 3:
-    {
-        f(std::integral_constant<std::size_t, 3>{});
-        break;
-    }
-    case 4:
-    {
-        f(std::integral_constant<std::size_t, 4>{});
-        break;
-    }
-    case 5:
-    {
-        f(std::integral_constant<std::size_t, 5>{});
-        break;
-    }
-    default: throw std::runtime_error("Unknown tensor size");
-    }
-}
-
-template <size_t NDim>
-struct hip_index
-{
-    size_t d[NDim];
-    __device__ __host__ size_t& operator[](size_t i) { return d[i]; }
-    __device__ __host__ size_t operator[](size_t i) const { return d[i]; }
-};
+template <std::size_t NDim>
+using hip_tensor_index = hip_array<std::size_t, NDim>;
 
-template <size_t NDim>
+template <std::size_t NDim>
 struct hip_tensor_descriptor
 {
     __device__ __host__ hip_tensor_descriptor() = default;
@@ -63,26 +22,26 @@ struct hip_tensor_descriptor
         std::copy(s.strides().begin(), s.strides().end(), strides);
     }
 
-    __device__ __host__ hip_index<NDim> multi(size_t idx) const
+    __device__ __host__ hip_tensor_index<NDim> multi(std::size_t idx) const
     {
-        hip_index<NDim> result{};
-        size_t tidx = idx;
-        for(size_t is = 0; is < NDim; is++)
+        hip_tensor_index<NDim> result{};
+        std::size_t tidx = idx;
+        for(std::size_t is = 0; is < NDim; is++)
         {
             result[is] = tidx / strides[is];
             tidx       = tidx % strides[is];
         }
         return result;
     }
-    __device__ __host__ size_t linear(hip_index<NDim> s) const
+    __device__ __host__ std::size_t linear(hip_tensor_index<NDim> s) const
     {
-        size_t idx = 0;
-        for(size_t i = 0; i < NDim; i++)
+        std::size_t idx = 0;
+        for(std::size_t i = 0; i < NDim; i++)
            idx += s[i] * strides[i];
         return idx;
    }
-    size_t lens[NDim]    = {};
-    size_t strides[NDim] = {};
+    std::size_t lens[NDim]    = {};
+    std::size_t strides[NDim] = {};
 };
 
 } // namespace device
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_TENSOR_VIEW_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_TENSOR_VIEW_HPP
#include <migraphx/gpu/device/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T, std::size_t N>
struct hip_tensor_view
{
using value_type = T;
using hip_index = typename hip_shape<N>::hip_index;
__device__ __host__ hip_tensor_view() = default;
__host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {}
__host__ hip_tensor_view(T* x, const shape& ss) : d(x), s(ss) {}
MIGRAPHX_DEVICE_CONSTEXPR const hip_shape<N>& get_shape() const { return s; }
MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return s.elements(); }
MIGRAPHX_DEVICE_CONSTEXPR value_type* data() const { return d; }
template <class U>
MIGRAPHX_DEVICE_CONSTEXPR value_type& operator[](U i) const
{
return d[s.index(i)];
}
MIGRAPHX_DEVICE_CONSTEXPR value_type* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR value_type* end() const { return d + size(); }
private:
value_type* d = nullptr;
hip_shape<N> s{};
};
template <std::size_t N, class T>
hip_tensor_view<T, N> make_hip_view(const shape& s, T* x)
{
return {x, s};
}
template <std::size_t N, class T>
hip_tensor_view<T, N> make_hip_view(tensor_view<T> x)
{
return {x};
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
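Since operator[] forwards to hip_shape::index, a view can be indexed either linearly or with a multi-index from multi(); a device-side sketch (not in the commit):

// Doubles every element of a view; View is e.g. hip_tensor_view<float, 2>.
template <class View>
__device__ void double_all(View v)
{
    for(std::size_t i = 0; i < v.size(); i++)
    {
        auto idx = v.get_shape().multi(i); // multi-index form
        v[idx]   = v[idx] * 2;             // same element as v[i] on standard shapes
    }
}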
@@ -8,14 +8,45 @@
 #ifndef MIGRAPHX_GUARD_RTGLIB_GPU_DEVICE_TYPES_HPP
 #define MIGRAPHX_GUARD_RTGLIB_GPU_DEVICE_TYPES_HPP
 
+#include <hip/hip_runtime.h>
 #include <migraphx/half.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/tensor_view.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
+#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
+
+template <class T, std::size_t N>
+using vec = T __attribute__((ext_vector_type(N)));
+
+template <std::size_t N, class T>
+__device__ __host__ T* as_pointer(vec<T, N>* x)
+{
+    return reinterpret_cast<T*>(x);
+}
+
+template <std::size_t N, class T>
+__device__ __host__ vec<T, N>* as_vec(T* x)
+{
+    return reinterpret_cast<vec<T, N>*>(x);
+}
+
+template <std::size_t N, class T>
+tensor_view<vec<T, N>> as_vec(tensor_view<T> x)
+{
+    return {x.get_shape(), as_vec<N>(x.data())};
+}
+
+template <std::size_t N, class... Ts>
+auto pack_vec(Ts... xs)
+{
+    return [=](auto f, std::size_t n) { return f(as_vec<N>(xs)[n]...); };
+}
+
 using gpu_half = __fp16;
 
 namespace detail {
@@ -25,6 +56,12 @@ struct device_type
     using type = T;
 };
 
+template <class T, std::size_t N>
+struct device_type<vec<T, N>>
+{
+    using type = vec<typename device_type<T>::type, N>;
+};
+
 template <>
 struct device_type<half>
 {
@@ -38,7 +75,7 @@ struct host_type
 };
 
 template <>
-struct device_type<gpu_half>
+struct host_type<gpu_half>
 {
     using type = half;
 };
@@ -75,6 +112,12 @@ device_type<T>* device_cast(T* x)
     return reinterpret_cast<device_type<T>*>(x);
 }
 
+template <class T>
+tensor_view<device_type<T>> device_cast(tensor_view<T> x)
+{
+    return {x.get_shape(), reinterpret_cast<device_type<T>*>(x.data())};
+}
+
 template <class T>
 __device__ __host__ T to_hip_type(T x)
 {
......
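The generalized vec/as_vec helpers replace the fixed vec4 utilities that used to live in nary.hpp; a sketch of the intended use (not in the commit):

// Reinterpret a packed float buffer as vectors of 4 to widen loads and stores;
// n must be divisible by 4.
__device__ void scale4(float* data, float a, std::size_t n)
{
    auto* v = migraphx::gpu::device::as_vec<4>(data);
    for(std::size_t i = 0; i < n / 4; i++)
        v[i] = v[i] * a; // one vector load/store per iteration
}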
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_VECTOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_VECTOR_HPP
#include <migraphx/gpu/device/types.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T, std::size_t N>
struct hip_vector
{
MIGRAPHX_DEVICE_CONSTEXPR hip_vector() = default;
MIGRAPHX_DEVICE_CONSTEXPR hip_vector(std::size_t s) : len(s) {}
template <class Iterator>
__device__ __host__ hip_vector(Iterator start, Iterator last)
{
auto it = std::copy(start, last, d);
len = std::distance(d, it);
}
__device__ __host__ hip_vector(std::initializer_list<T> x)
{
std::copy(x.begin(), x.end(), d);
len = x.size();
}
MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR T& front() { return d[0]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& front() const { return d[0]; }
MIGRAPHX_DEVICE_CONSTEXPR T& back() { return d[size() - 1]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& back() const { return d[size() - 1]; }
MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return len; }
MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d + size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d + size(); }
template <class U>
MIGRAPHX_DEVICE_CONSTEXPR void push_back(U&& x)
{
d[len] = static_cast<U&&>(x);
len++;
}
private:
T d[N] = {};
std::size_t len = 0;
};
template <std::size_t N, class T>
hip_vector<T, N> to_hip_vector(const std::vector<T>& x)
{
hip_vector<T, N> result(x.size());
std::copy(x.begin(), x.end(), result.begin());
return result;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
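hip_vector is a fixed-capacity, trivially copyable stand-in for std::vector, so it can be captured by value into device lambdas; the capacity N is a compile-time bound, while size() is set at run time. Sketch (not in the commit):

#include <cassert>
#include <vector>

void example()
{
    std::vector<std::size_t> pads = {1, 1, 2, 2};
    // Capacity 8 is a compile-time upper bound; the runtime size stays 4.
    auto dpads = migraphx::gpu::device::to_hip_vector<8>(pads);
    assert(dpads.size() == 4 and dpads.back() == 2);
}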
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_VISIT_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_VISIT_HPP
#include <migraphx/gpu/device/tensor_view.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class F>
void visit_tensor_size(std::size_t n, F f)
{
switch(n)
{
case 1:
{
f(std::integral_constant<std::size_t, 1>{});
break;
}
case 2:
{
f(std::integral_constant<std::size_t, 2>{});
break;
}
case 3:
{
f(std::integral_constant<std::size_t, 3>{});
break;
}
case 4:
{
f(std::integral_constant<std::size_t, 4>{});
break;
}
case 5:
{
f(std::integral_constant<std::size_t, 5>{});
break;
}
default: throw std::runtime_error("Unknown tensor size");
}
}
inline shape get_shape(const shape& x) { return x; }
template <class T>
auto get_shape(const T& x) -> decltype(x.get_shape())
{
return x.get_shape();
}
template <class V, class F, class... Ts>
void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
std::initializer_list<migraphx::shape::type_t> types = {get_shape(xs).type()...};
if(!std::all_of(
types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same");
std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...};
if(!std::all_of(
ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(),
[&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); });
}
template <class V, class F, class... Ts>
void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...};
if(!std::all_of(
ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
}
template <class F>
struct hip_convert
{
F f;
template <class RawData, class N, class As>
auto operator()(RawData x, N ndim, As as) const
-> decltype(make_hip_view<ndim>(x.get_shape(), f(as.from(x.data()))))
{
return make_hip_view<ndim>(x.get_shape(), f(as.from(x.data())));
}
template <class N, class As>
auto operator()(const shape& s, N ndim, As) const
{
return make_hip_shape<ndim>(s);
}
};
template <class F>
hip_convert<F> make_hip_convert(F f)
{
return {f};
}
template <class F>
struct hip_convert_view
{
F f;
template <class T, class N>
auto operator()(tensor_view<T> x, N ndim) const
{
return make_hip_view<ndim>(f(x));
}
template <class N>
auto operator()(const shape& s, N ndim) const
{
return make_hip_shape<ndim>(s);
}
};
template <class F>
hip_convert_view<F> make_hip_convert_view(F f)
{
return {f};
}
template <class T, class... Ts>
auto hip_visit_all(T&& x, Ts&&... xs)
{
return [&](auto f) {
hip_visit_all_impl(
get_shape(x), make_hip_convert([](auto* p) { return device_cast(p); }), f, x, xs...);
};
}
template <std::size_t N, class T, class... Ts>
auto hip_vec_visit_all(T&& x, Ts&&... xs)
{
return [&](auto f) {
hip_visit_all_impl(get_shape(x),
make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }),
f,
x,
xs...);
};
}
template <class T, class... Ts>
auto hip_pointer_visit_all(T&& x, Ts&&... xs)
{
return [&](auto f) { visit_all(x, xs...)([&](auto... vs) { f(device_cast(vs.data())...); }); };
}
template <class T, class... Ts>
auto hip_visit_views(T&& x, Ts&&... xs)
{
return [&](auto f) {
hip_visit_views_impl(get_shape(x),
make_hip_convert_view([](auto v) { return device_cast(v); }),
f,
x,
xs...);
};
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
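The visit helpers do the type and rank dispatch once on the host and hand the kernel lambda lightweight, by-value-capturable views; a sketch of an elementwise copy built on them (not in the commit; copy_elementwise is hypothetical):

// Inside namespace migraphx::gpu::device.
void copy_elementwise(hipStream_t stream, argument result, argument arg)
{
    std::size_t nelements = result.get_shape().elements();
    hip_visit_all(result, arg)([&](auto output, auto input) {
        // output and input are hip_tensor_views; operator[] goes through
        // hip_shape::index, so non-packed layouts are handled.
        gs_launch(stream, nelements)([=](auto i) { output[i] = input[i]; });
    });
}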
@@ -12,7 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
+void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
 {
     auto lens = result.get_shape().lens();
@@ -21,82 +21,75 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
     batch_lens[axis] = 1;
     migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
 
-    visit_all(result, arg)([&](auto output, auto input) {
-        const auto* input_ptr = device_cast(input.data());
-        auto* output_ptr      = device_cast(output.data());
-        visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
-            hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
-            hip_tensor_descriptor<n_dim> desc_data(result.get_shape());
-
-            // use one block for items in one batch.
-            const size_t max_block_size = 1024;
-            size_t block_size           = 1;
-            while(block_size < max_block_size and block_size < batch_item_num)
-            {
-                block_size *= 2;
-            }
-
-            launch(
-                stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
-                size_t thr_idx = idx.local;
-                size_t blk_idx = idx.group;
-                using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-                MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
-                auto batch_idx = desc_batch.multi(blk_idx);
-                auto data_idx  = batch_idx;
-                // load data to lds and compute the batch max
-                size_t remaining_item_num = batch_item_num;
-                size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
-                lds_data[block_size]  = input_ptr[0];
-                for(size_t i = thr_idx; i < round_item_num; i += block_size)
-                {
-                    if(i < batch_item_num)
-                    {
-                        data_idx[axis]    = i;
-                        lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
-                    }
-                    __syncthreads();
-                    auto item_num =
-                        (remaining_item_num > block_size) ? block_size : remaining_item_num;
-                    reduce_max(lds_data, block_size, thr_idx, item_num);
-                    remaining_item_num -= block_size;
-                }
-                auto batch_max = lds_data[block_size];
-                __syncthreads();
-
-                lds_data[block_size] = 0;
-                remaining_item_num   = batch_item_num;
-                for(size_t i = thr_idx; i < round_item_num; i += block_size)
-                {
-                    if(i < batch_item_num)
-                    {
-                        data_idx[axis]    = i;
-                        lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)] - batch_max;
-                        lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
-                    }
-                    __syncthreads();
-                    auto item_num =
-                        (remaining_item_num > block_size) ? block_size : remaining_item_num;
-                    reduce_sum(lds_data, block_size, thr_idx, item_num);
-                    remaining_item_num -= block_size;
-                }
-                auto log_batch_sum = ::log(to_hip_type(lds_data[block_size])) + batch_max;
-
-                for(size_t i = thr_idx; i < batch_item_num; i += block_size)
-                {
-                    data_idx[axis]    = i;
-                    size_t index      = desc_data.linear(data_idx);
-                    output_ptr[index] = input_ptr[index] - log_batch_sum;
-                }
-            });
+    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
+        // use one block for items in one batch.
+        const size_t max_block_size = 1024;
+        size_t block_size           = 1;
+        while(block_size < max_block_size and block_size < batch_item_num)
+        {
+            block_size *= 2;
+        }
+
+        launch(
+            stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
+            size_t thr_idx = idx.local;
+            size_t blk_idx = idx.group;
+            using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
+            MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
+            auto batch_idx = batch.multi(blk_idx);
+            auto data_idx  = batch_idx;
+            // load data to lds and compute the batch max
+            size_t remaining_item_num = batch_item_num;
+            size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
+            lds_data[block_size]  = input[0];
+            for(size_t i = thr_idx; i < round_item_num; i += block_size)
+            {
+                if(i < batch_item_num)
+                {
+                    data_idx[axis]    = i;
+                    // hip_tensor_view indexes through its shape
+                    lds_data[thr_idx] = input[data_idx];
+                }
+                __syncthreads();
+                auto item_num =
+                    (remaining_item_num > block_size) ? block_size : remaining_item_num;
+                reduce_max(lds_data, block_size, thr_idx, item_num);
+                remaining_item_num -= block_size;
+            }
+            auto batch_max = lds_data[block_size];
+            __syncthreads();
+
+            lds_data[block_size] = 0;
+            remaining_item_num   = batch_item_num;
+            for(size_t i = thr_idx; i < round_item_num; i += block_size)
+            {
+                if(i < batch_item_num)
+                {
+                    data_idx[axis]    = i;
+                    lds_data[thr_idx] = input[data_idx] - batch_max;
+                    lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
+                }
+                __syncthreads();
+                auto item_num =
+                    (remaining_item_num > block_size) ? block_size : remaining_item_num;
+                reduce_sum(lds_data, block_size, thr_idx, item_num);
+                remaining_item_num -= block_size;
+            }
+            auto log_batch_sum = ::log(to_hip_type(lds_data[block_size])) + batch_max;
+
+            for(size_t i = thr_idx; i < batch_item_num; i += block_size)
+            {
+                data_idx[axis]   = i;
+                output[data_idx] = input[data_idx] - log_batch_sum;
+            }
         });
     });
 }
......
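For reference, the kernel evaluates the numerically stable identity

    logsoftmax(x_i) = x_i - max_j(x_j) - log(sum_j exp(x_j - max_j(x_j)))

batch_max holds the block-reduced maximum and log_batch_sum the log of the shifted exponential sum, so the exponentials cannot overflow for large inputs.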
@@ -15,33 +15,26 @@ argument
 pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
 {
     std::size_t nelements = arg1.get_shape().elements();
-    visit_all(result)([&](auto output) {
-        auto* outptr = device_cast(output.data());
-        using type   = typename decltype(output)::value_type;
-        device_type<type> device_val = value;
+    hip_visit_all(result, arg1)([&](auto output, auto input) {
+        using type      = typename decltype(output)::value_type;
+        using hip_index = typename decltype(output)::hip_index;
+        type device_val = value;
         if(float_equal(value, std::numeric_limits<float>::lowest()))
         {
             device_val = device_cast(std::numeric_limits<type>::lowest());
         }
-        gs_launch(stream, result.get_shape().elements())([=](auto i) { outptr[i] = device_val; });
-    });
-    visit_all(result, arg1)([&](auto output, auto input) {
-        visit_tensor_size(result.get_shape().lens().size(), [&](auto ndim) {
-            std::size_t offsets[ndim];
-            std::copy(pads.begin(), pads.begin() + ndim, offsets);
-            auto* outptr      = output.data();
-            const auto* inptr = input.data();
-            hip_tensor_descriptor<ndim> desc_input(input.get_shape());
-            hip_tensor_descriptor<ndim> desc_output(output.get_shape());
-            gs_launch(stream, nelements)([=](auto i) {
-                auto idx = desc_input.multi(i);
-                for(std::size_t j = 0; j < ndim; j++)
-                {
-                    idx[j] += offsets[j];
-                }
-                outptr[desc_output.linear(idx)] = inptr[i];
-            });
+        gs_launch(stream,
+                  result.get_shape().elements())([=](auto i) { output.data()[i] = device_val; });
+        hip_index offsets;
+        std::copy(pads.begin(), pads.begin() + offsets.size(), offsets.begin());
+        gs_launch(stream, nelements)([=](auto i) {
+            auto idx = input.get_shape().multi(i);
+            for(std::size_t j = 0; j < offsets.size(); j++)
+            {
+                idx[j] += offsets[j];
+            }
+            output[idx] = input.data()[i];
         });
     });
     return result;
......
@@ -13,7 +13,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
+void softmax(hipStream_t stream, argument result, argument arg, int axis)
 {
     auto lens       = result.get_shape().lens();
     auto batch_lens = lens;
@@ -21,83 +21,76 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
     batch_lens[axis] = 1;
     migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
 
-    visit_all(result, arg)([&](auto output, auto input) {
-        const auto* input_ptr = device_cast(input.data());
-        auto* output_ptr      = device_cast(output.data());
-        visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
-            hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
-            hip_tensor_descriptor<n_dim> desc_data(result.get_shape());
-
-            // use one block for items in one batch.
-            const size_t max_block_size = 1024;
-            size_t block_size           = 1;
-            while(block_size < max_block_size and block_size < batch_item_num)
-            {
-                block_size *= 2;
-            }
-
-            launch(
-                stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
-                size_t thr_idx = idx.local;
-                size_t blk_idx = idx.group;
-                using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-                MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
-                auto batch_idx = desc_batch.multi(blk_idx);
-                auto data_idx  = batch_idx;
-                // load data to lds and compute the batch max
-                size_t remaining_item_num = batch_item_num;
-                size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
-                lds_data[block_size]  = input_ptr[0];
-                for(size_t i = thr_idx; i < round_item_num; i += block_size)
-                {
-                    if(i < batch_item_num)
-                    {
-                        data_idx[axis]    = i;
-                        lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
-                    }
-                    __syncthreads();
-                    auto item_num =
-                        (remaining_item_num > block_size) ? block_size : remaining_item_num;
-                    reduce_max(lds_data, block_size, thr_idx, item_num);
-                    remaining_item_num -= block_size;
-                }
-                auto batch_max = lds_data[block_size];
-                __syncthreads();
-
-                lds_data[block_size] = 0;
-                remaining_item_num   = batch_item_num;
-                for(size_t i = thr_idx; i < round_item_num; i += block_size)
-                {
-                    if(i < batch_item_num)
-                    {
-                        data_idx[axis]    = i;
-                        lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)] - batch_max;
-                        lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
-                    }
-                    __syncthreads();
-                    auto item_num =
-                        (remaining_item_num > block_size) ? block_size : remaining_item_num;
-                    reduce_sum(lds_data, block_size, thr_idx, item_num);
-                    remaining_item_num -= block_size;
-                }
-                auto batch_sum = lds_data[block_size];
-
-                for(size_t i = thr_idx; i < batch_item_num; i += block_size)
-                {
-                    data_idx[axis]    = i;
-                    size_t index      = desc_data.linear(data_idx);
-                    auto val          = input_ptr[index] - batch_max;
-                    output_ptr[index] = ::exp(to_hip_type(val)) / batch_sum;
-                }
-            });
+    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
+        // use one block for items in one batch.
+        const size_t max_block_size = 1024;
+        size_t block_size           = 1;
+        while(block_size < max_block_size and block_size < batch_item_num)
+        {
+            block_size *= 2;
+        }
+
+        launch(
+            stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
+            size_t thr_idx = idx.local;
+            size_t blk_idx = idx.group;
+            using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
+            MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
+            auto batch_idx = batch.multi(blk_idx);
+            auto data_idx  = batch_idx;
+            // load data to lds and compute the batch max
+            size_t remaining_item_num = batch_item_num;
+            size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
+            lds_data[block_size]  = input[0];
+            for(size_t i = thr_idx; i < round_item_num; i += block_size)
+            {
+                if(i < batch_item_num)
+                {
+                    data_idx[axis]    = i;
+                    // hip_tensor_view indexes through its shape
+                    lds_data[thr_idx] = input[data_idx];
+                }
+                __syncthreads();
+                auto item_num =
+                    (remaining_item_num > block_size) ? block_size : remaining_item_num;
+                reduce_max(lds_data, block_size, thr_idx, item_num);
+                remaining_item_num -= block_size;
+            }
+            auto batch_max = lds_data[block_size];
+            __syncthreads();
+
+            lds_data[block_size] = 0;
+            remaining_item_num   = batch_item_num;
+            for(size_t i = thr_idx; i < round_item_num; i += block_size)
+            {
+                if(i < batch_item_num)
+                {
+                    data_idx[axis]    = i;
+                    lds_data[thr_idx] = input[data_idx] - batch_max;
+                    lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
+                }
+                __syncthreads();
+                auto item_num =
+                    (remaining_item_num > block_size) ? block_size : remaining_item_num;
+                reduce_sum(lds_data, block_size, thr_idx, item_num);
+                remaining_item_num -= block_size;
+            }
+            auto batch_sum = lds_data[block_size];
+
+            for(size_t i = thr_idx; i < batch_item_num; i += block_size)
+            {
+                data_idx[axis]   = i;
+                auto val         = input[data_idx] - batch_max;
+                output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
+            }
        });
     });
 }
......
@@ -5,6 +5,7 @@
 #include <migraphx/gpu/device/add_relu.hpp>
 #include <migraphx/gpu/device/add.hpp>
 #include <migraphx/instruction.hpp>
+#include <migraphx/array.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -122,13 +123,6 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
            s.strides()[1] != 0 and s.strides()[2] == 0 and s.strides()[3] == 0;
 }
 
-// TODO: Move to another header
-template <class T, class... Ts>
-std::array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
-{
-    return {std::move(x), std::move(static_cast<T>(xs))...};
-}
-
 MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
 {
     if(ins->name() != "gpu::convolution")
@@ -408,8 +402,8 @@ void fuse_ops::apply(program& p) const
     // clang-format off
     match::find_matches(p, find_triadd{});
     match::find_matches(p,
-        find_conv_bias_relu{ctx},
-        find_conv_bias{ctx},
+        // find_conv_bias_relu{ctx},
+        // find_conv_bias{ctx},
         find_add_relu{}
     );
     // clang-format on
......
@@ -12,11 +12,9 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const
     return op.compute_shape(inputs);
 }
 
-argument hip_gather::compute(context& ctx,
-                             const shape& output_shape,
-                             const std::vector<argument>& args) const
+argument hip_gather::compute(context& ctx, const shape&, const std::vector<argument>& args) const
 {
-    return device::gather(ctx.get_stream().get(), output_shape, args, op.axis);
+    return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis);
 }
 
 } // namespace gpu
......
@@ -10,10 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-argument gather(hipStream_t stream,
-                const migraphx::shape& output_shape,
-                std::vector<migraphx::argument> args,
-                int axis);
+argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis);
 
 } // namespace device
 } // namespace gpu
......