Commit 193a3b7c authored by Paul Fultz II, committed by mvermeulen

Use 32-bit integers for index calculations on the gpu (#387)

* use 32bit integers for indices

* Formatting

* Update more index types

* Formatting
parent dc23d605
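
The diff below adds `using index_int = std::uint32_t;` to the device types header (the header newly included by launch.hpp in this commit) and switches loop counters, shape/stride fields, and launch parameters from std::size_t to index_int, so index arithmetic inside kernels uses 32-bit operations; only the fast_div helpers stay 64-bit. A minimal host-side sketch of the pattern, outside of MIGraphX and assuming element counts fit in 32 bits (for_each_index is a hypothetical illustration, not library code):

#include <cstdint>

// The alias introduced by this commit (in the device types header):
using index_int = std::uint32_t;

// Illustrative only: a grid-stride style loop written with 32-bit indices.
// The caller is responsible for ensuring nelements fits in 32 bits.
template <class F>
void for_each_index(index_int nelements, index_int start, index_int stride, F f)
{
    for(index_int i = start; i < nelements; i += stride)
        f(i);
}
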
@@ -13,13 +13,13 @@ namespace device {
 #define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op) \
     MIGRAPHX_DEVICE_CONSTEXPR hip_array& operator op(const hip_array& x) \
     { \
-        for(std::size_t i = 0; i < N; i++) \
+        for(index_int i = 0; i < N; i++) \
             d[i] op x[i]; \
         return *this; \
     } \
     MIGRAPHX_DEVICE_CONSTEXPR hip_array& operator op(const T& x) \
     { \
-        for(std::size_t i = 0; i < N; i++) \
+        for(index_int i = 0; i < N; i++) \
             d[i] op x; \
         return *this; \
     } \
@@ -36,12 +36,12 @@ namespace device {
         return x op y; \
     }
-template <class T, std::size_t N>
+template <class T, index_int N>
 struct hip_array
 {
     T d[N];
-    MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; }
-    MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; }
+    MIGRAPHX_DEVICE_CONSTEXPR T& operator[](index_int i) { return d[i]; }
+    MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](index_int i) const { return d[i]; }
     MIGRAPHX_DEVICE_CONSTEXPR T& front() { return d[0]; }
     MIGRAPHX_DEVICE_CONSTEXPR const T& front() const { return d[0]; }
@@ -52,7 +52,7 @@ struct hip_array
     MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; }
     MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; }
-    MIGRAPHX_DEVICE_CONSTEXPR std::integral_constant<std::size_t, N> size() const { return {}; }
+    MIGRAPHX_DEVICE_CONSTEXPR std::integral_constant<index_int, N> size() const { return {}; }
     MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
     MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }
@@ -63,7 +63,7 @@ struct hip_array
     MIGRAPHX_DEVICE_CONSTEXPR T dot(const hip_array& x) const
     {
         T result = 0;
-        for(std::size_t i = 0; i < N; i++)
+        for(index_int i = 0; i < N; i++)
             result += x[i] * d[i];
         return result;
     }
@@ -71,16 +71,16 @@ struct hip_array
     MIGRAPHX_DEVICE_CONSTEXPR T product() const
     {
         T result = 1;
-        for(std::size_t i = 0; i < N; i++)
+        for(index_int i = 0; i < N; i++)
             result *= d[i];
         return result;
     }
-    MIGRAPHX_DEVICE_CONSTEXPR T single(std::size_t width = 100) const
+    MIGRAPHX_DEVICE_CONSTEXPR T single(index_int width = 100) const
     {
         T result = 0;
         T a      = 1;
-        for(std::size_t i = 0; i < N; i++)
+        for(index_int i = 0; i < N; i++)
         {
             result += d[N - i - 1] * a;
             a *= width;
@@ -98,7 +98,7 @@ struct hip_array
     friend MIGRAPHX_DEVICE_CONSTEXPR bool operator==(const hip_array& x, const hip_array& y)
     {
-        for(std::size_t i = 0; i < N; i++)
+        for(index_int i = 0; i < N; i++)
         {
             if(x[i] != y[i])
                 return false;
@@ -113,7 +113,7 @@ struct hip_array
     // This uses the product order rather than lexical order
     friend MIGRAPHX_DEVICE_CONSTEXPR bool operator<(const hip_array& x, const hip_array& y)
     {
-        for(std::size_t i = 0; i < N; i++)
+        for(index_int i = 0; i < N; i++)
         {
             if(not(x[i] < y[i]))
                 return false;
...
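
hip_array::single(width) packs the array into one integer with base width (default 100), treating d[N-1] as the least significant digit. For example, with d = {1, 2, 3} the loop accumulates 3*1 + 2*100 + 1*10000 = 10203. A small host-side analogue, purely illustrative:

#include <cstddef>
#include <cstdint>
#include <iostream>

using index_int = std::uint32_t;

// Host-side analogue of hip_array<T, N>::single(width): packs the digits in d
// into one number, with d[N - 1] as the least-significant base-'width' digit.
template <class T, std::size_t N>
T single(const T (&d)[N], index_int width = 100)
{
    T result = 0;
    T a      = 1;
    for(index_int i = 0; i < N; i++)
    {
        result += d[N - i - 1] * a;
        a *= width;
    }
    return result;
}

int main()
{
    int d[3] = {1, 2, 3};
    std::cout << single(d) << "\n"; // prints 10203: 3*1 + 2*100 + 1*10000
}
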
@@ -8,33 +8,33 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-constexpr const std::size_t fast_div_shift = 42;
-inline std::size_t encode_divisor(std::size_t divisor)
+constexpr const uint64_t fast_div_shift = 42;
+inline uint64_t encode_divisor(uint64_t divisor)
 {
     if(divisor == 0)
         return 0;
-    auto p = std::size_t{1} << fast_div_shift;
+    auto p = uint64_t{1} << fast_div_shift;
     return (p + divisor - 1) / divisor;
 }
-inline constexpr bool is_divisor_encodable(std::size_t i)
+inline constexpr bool is_divisor_encodable(uint64_t i)
 {
-    return i < (std::size_t{1} << (fast_div_shift / 2));
+    return i < (uint64_t{1} << (fast_div_shift / 2));
 }
-MIGRAPHX_DEVICE_CONSTEXPR std::size_t fast_div(std::size_t dividend, std::size_t encoded_divisor)
+MIGRAPHX_DEVICE_CONSTEXPR uint64_t fast_div(uint64_t dividend, uint64_t encoded_divisor)
 {
     return (dividend * encoded_divisor) >> fast_div_shift;
 }
-MIGRAPHX_DEVICE_CONSTEXPR std::size_t
-remainder(std::size_t result, std::size_t dividend, std::size_t divisor)
+MIGRAPHX_DEVICE_CONSTEXPR uint64_t remainder(uint64_t result, uint64_t dividend, uint64_t divisor)
 {
     return dividend - divisor * result;
 }
-MIGRAPHX_DEVICE_CONSTEXPR std::size_t
-fast_mod(std::size_t dividend, std::size_t divisor, std::size_t encoded_divisor)
+MIGRAPHX_DEVICE_CONSTEXPR uint64_t fast_mod(uint64_t dividend,
+                                            uint64_t divisor,
+                                            uint64_t encoded_divisor)
 {
     return remainder(fast_div(dividend, encoded_divisor), dividend, divisor);
 }
...
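
The fast_div trick replaces division by a runtime divisor with a multiply and a shift: encode_divisor(d) precomputes ceil(2^42 / d) on the host, and fast_div(x, e) = (x * e) >> 42 matches x / d for the ranges guarded by is_divisor_encodable (divisors below 2^21). A small host-side spot check, not part of the library; the divisor set and dividend range are chosen arbitrarily for illustration:

#include <cassert>
#include <cstdint>

constexpr const uint64_t fast_div_shift = 42;

uint64_t encode_divisor(uint64_t divisor)
{
    if(divisor == 0)
        return 0;
    auto p = uint64_t{1} << fast_div_shift;
    return (p + divisor - 1) / divisor; // ceil(2^42 / divisor)
}

uint64_t fast_div(uint64_t dividend, uint64_t encoded_divisor)
{
    return (dividend * encoded_divisor) >> fast_div_shift;
}

int main()
{
    // Compare the multiply-and-shift division against real division
    // for a few divisors and a range of dividends (illustrative bounds only).
    for(uint64_t d : {1u, 3u, 7u, 100u, 1024u, 65535u})
    {
        auto e = encode_divisor(d);
        for(uint64_t x = 0; x < (uint64_t{1} << 20); x++)
            assert(fast_div(x, e) == x / d);
    }
}
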
@@ -3,6 +3,7 @@
 #include <hip/hip_runtime.h>
 #include <migraphx/config.hpp>
+#include <migraphx/gpu/device/types.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -11,29 +12,29 @@ namespace device {
 struct index
 {
-    std::size_t global = 0;
-    std::size_t local  = 0;
-    std::size_t group  = 0;
+    index_int global = 0;
+    index_int local  = 0;
+    index_int group  = 0;
-    __device__ std::size_t nglobal() const { return blockDim.x * gridDim.x; } // NOLINT
-    __device__ std::size_t nlocal() const { return blockDim.x; } // NOLINT
+    __device__ index_int nglobal() const { return blockDim.x * gridDim.x; } // NOLINT
+    __device__ index_int nlocal() const { return blockDim.x; } // NOLINT
     template <class F>
-    __device__ void global_stride(std::size_t n, F f) const
+    __device__ void global_stride(index_int n, F f) const
     {
         const auto stride = nglobal();
-        for(std::size_t i = global; i < n; i += stride)
+        for(index_int i = global; i < n; i += stride)
         {
             f(i);
         }
     }
     template <class F>
-    __device__ void local_stride(std::size_t n, F f) const
+    __device__ void local_stride(index_int n, F f) const
     {
         const auto stride = nlocal();
-        for(std::size_t i = local; i < n; i += stride)
+        for(index_int i = local; i < n; i += stride)
         {
             f(i);
         }
@@ -47,7 +48,7 @@ __global__ void launcher(F f)
     f(idx);
 }
-inline auto launch(hipStream_t stream, std::size_t global, std::size_t local)
+inline auto launch(hipStream_t stream, index_int global, index_int local)
 {
     return [=](auto f) {
         assert(local > 0);
@@ -60,21 +61,21 @@ inline auto launch(hipStream_t stream, std::size_t global, std::size_t local)
 }
 template <class F>
-__host__ __device__ auto gs_invoke(F&& f, std::size_t i, index idx) -> decltype(f(i, idx))
+__host__ __device__ auto gs_invoke(F&& f, index_int i, index idx) -> decltype(f(i, idx))
 {
     return f(i, idx);
 }
 template <class F>
-__host__ __device__ auto gs_invoke(F&& f, std::size_t i, index) -> decltype(f(i))
+__host__ __device__ auto gs_invoke(F&& f, index_int i, index) -> decltype(f(i))
 {
     return f(i);
 }
-inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 1024)
+inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
 {
-    std::size_t groups  = (n + local - 1) / local;
-    std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
+    index_int groups  = (n + local - 1) / local;
+    index_int nglobal = std::min<index_int>(256, groups) * local;
     return [=](auto f) {
         launch(stream, nglobal, local)(
...
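
gs_launch rounds n up to whole blocks of local threads and caps the grid at 256 blocks, so one thread may cover several elements via the global_stride loop. For example, with n = 1,000,000 and local = 1024: groups = 977, nglobal = min(256, 977) * 1024 = 262,144, and each thread strides over at most ceil(1,000,000 / 262,144) = 4 elements. A host-side rehearsal of that arithmetic (print_launch_sizes is a hypothetical helper, not library code):

#include <algorithm>
#include <cstdint>
#include <iostream>

using index_int = std::uint32_t;

// Mirrors the sizing arithmetic in gs_launch: round n up to whole blocks,
// then cap the grid at 256 blocks of 'local' threads.
void print_launch_sizes(index_int n, index_int local = 1024)
{
    index_int groups     = (n + local - 1) / local;
    index_int nglobal    = std::min<index_int>(256, groups) * local;
    index_int per_thread = (n + nglobal - 1) / nglobal; // elements each thread strides over
    std::cout << "groups=" << groups << " nglobal=" << nglobal
              << " elements/thread<=" << per_thread << "\n";
}

int main() { print_launch_sizes(1000000); } // groups=977 nglobal=262144 elements/thread<=4
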
@@ -10,10 +10,10 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-template <std::size_t N>
+template <index_int N>
 struct multi_index
 {
-    using hip_index = hip_array<std::size_t, N>;
+    using hip_index = hip_array<index_int, N>;
     hip_index id{};
     hip_index stride{};
@@ -27,28 +27,28 @@ struct multi_index
     }
 };
-template <std::size_t N>
+template <index_int N>
 MIGRAPHX_DEVICE_CONSTEXPR multi_index<N>
-make_multi_index(const hip_shape<N>& s, std::size_t i, std::size_t n)
+make_multi_index(const hip_shape<N>& s, index_int i, index_int n)
 {
     return {s.multi(i), s.multi(n)};
 }
-template <std::size_t N>
+template <index_int N>
 MIGRAPHX_DEVICE_CONSTEXPR multi_index<N>
-make_multi_index(const hip_shape<N>& s, std::size_t i, const hip_array<std::size_t, N>& n)
+make_multi_index(const hip_shape<N>& s, index_int i, const hip_array<index_int, N>& n)
 {
     return {s.multi(i), n};
 }
-template <std::size_t N>
-inline auto mi_launch(hipStream_t stream, const hip_shape<N>& s, std::size_t local = 1024)
+template <index_int N>
+inline auto mi_launch(hipStream_t stream, const hip_shape<N>& s, index_int local = 1024)
 {
     assert(s.standard);
     assert(s.elements() > 0);
-    std::size_t n       = s.elements();
-    std::size_t groups  = (n + local - 1) / local;
-    std::size_t nglobal = std::min<std::size_t>(128, groups) * local;
+    index_int n       = s.elements();
+    index_int groups  = (n + local - 1) / local;
+    index_int nglobal = std::min<index_int>(128, groups) * local;
     assert(groups > 0);
     assert(nglobal > 0);
...
@@ -40,7 +40,7 @@ auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, A
     });
 }
-inline auto create_broadcast_index(std::size_t len, std::size_t stride)
+inline auto create_broadcast_index(index_int len, index_int stride)
 {
     auto next_stride   = stride * len;
     auto e_next_stride = encode_divisor(next_stride);
@@ -83,14 +83,14 @@ void nary_broadcast_vec_impl(
     auto bdim_stride   = output_shape.strides()[bdim];
     auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
-    const std::size_t vec_size     = 4;
-    const std::size_t nlocal       = 1024;
-    const std::size_t nglobal      = 256 * nlocal;
-    const std::size_t bdim_vec_len = bdim_len / vec_size;
+    const index_int vec_size     = 4;
+    const index_int nlocal       = 1024;
+    const index_int nglobal      = 256 * nlocal;
+    const index_int bdim_vec_len = bdim_len / vec_size;
     hip_vec_visit_all<vec_size>(result, barg, args...)(
         [&](auto output, auto binput, auto... inputs) {
             using type = typename decltype(output)::value_type;
-            const std::size_t nelements = output.size() / vec_size;
+            const index_int nelements = output.size() / vec_size;
             launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
                 MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
@@ -107,7 +107,7 @@ void nary_broadcast_vec_impl(
                     auto bidx = broadcast_idx(i * vec_size);
                     auto b    = bp[bidx];
                     auto out  = output.data()[i];
-                    for(std::size_t j = 0; j < vec_size; j++)
+                    for(index_int j = 0; j < vec_size; j++)
                     {
                         out[j] = f(inputs.data()[i][j]..., b);
                     }
@@ -132,9 +132,9 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
     auto bdim_stride   = output_shape.strides()[bdim];
     auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
-    const std::size_t nlocal  = 1024;
-    const std::size_t nglobal = 256 * nlocal;
-    std::size_t nelements     = result.get_shape().elements();
+    const index_int nlocal  = 1024;
+    const index_int nglobal = 256 * nlocal;
+    index_int nelements     = result.get_shape().elements();
     hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
         using type = typename decltype(output)::value_type;
         launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
@@ -175,14 +175,14 @@ void nary_double_broadcast_vec_impl(
     auto bdim_stride   = output_shape.strides()[bdim];
     auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
-    const std::size_t vec_size     = 4;
-    const std::size_t nlocal       = 1024;
-    const std::size_t nglobal      = 256 * nlocal;
-    const std::size_t bdim_vec_len = bdim_len / vec_size;
+    const index_int vec_size     = 4;
+    const index_int nlocal       = 1024;
+    const index_int nglobal      = 256 * nlocal;
+    const index_int bdim_vec_len = bdim_len / vec_size;
     hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
         [&](auto output, auto binput1, auto binput2, auto... inputs) {
             using type = typename decltype(output)::value_type;
-            const std::size_t nelements = output.size() / vec_size;
+            const index_int nelements = output.size() / vec_size;
             launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
                 MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
@@ -204,7 +204,7 @@ void nary_double_broadcast_vec_impl(
                     auto b1  = bp[bidx];
                     auto b2  = bp[bidx + bdim_len];
                     auto out = output.data()[i];
-                    for(std::size_t j = 0; j < vec_size; j++)
+                    for(index_int j = 0; j < vec_size; j++)
                     {
                         out[j] = f(inputs.data()[i][j]..., b2, b1);
                     }
@@ -233,9 +233,9 @@ void nary_double_broadcast_impl(
     auto bdim_stride   = output_shape.strides()[bdim];
     auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
-    const std::size_t nlocal  = 1024;
-    const std::size_t nglobal = 256 * nlocal;
-    std::size_t nelements     = result.get_shape().elements();
+    const index_int nlocal  = 1024;
+    const index_int nglobal = 256 * nlocal;
+    index_int nelements     = result.get_shape().elements();
     hip_visit_all(result, barg1, barg2, args...)(
         [&](auto output, auto binput1, auto binput2, auto... inputs) {
             using type = typename decltype(output)::value_type;
@@ -270,14 +270,14 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
     const auto& output_shape = result.get_shape();
     visit_all(result, args...)([&](auto output, auto... inputs) {
         using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        const std::size_t vec_size = 4;
+        const index_int vec_size = 4;
         auto data  = pack_vec<4>(device_cast(inputs.data())...);
         auto* outp = as_vec<4>(device_cast(output.data()));
         gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) {
             vec<type, 4> out = outp[i];
             data(
                 [&](auto... xs) {
-                    for(std::size_t j = 0; j < vec_size; j++)
+                    for(index_int j = 0; j < vec_size; j++)
                     {
                         out[j] = f(xs[j]...);
                     }
@@ -292,7 +292,7 @@ template <class F, class... Arguments>
 void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
     MIGRAPHX_TRACE_NARY_FUNCTION
-    std::size_t nelements = result.get_shape().elements();
+    index_int nelements = result.get_shape().elements();
     hip_pointer_visit_all(result, args...)([&](auto output, auto... inputs) {
         gs_launch(stream, nelements)([=](auto i) { output[i] = f(inputs[i]...); });
     });
@@ -331,7 +331,7 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
 template <class... Arguments>
 bool broadcastable(bool& divisible_by_4,
-                   std::size_t max_size,
+                   index_int max_size,
                    const argument& result,
                    const argument& barg,
                    const Arguments&... args)
@@ -363,7 +363,7 @@ bool broadcastable(bool& divisible_by_4,
         return false;
 }
-inline bool broadcastable(bool& divisible_by_4, std::size_t, const argument&, const argument&)
+inline bool broadcastable(bool& divisible_by_4, index_int, const argument&, const argument&)
 {
     divisible_by_4 = false;
     return false;
...
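
nary_standard_vec_impl processes tensors four elements at a time: the pointers are reinterpreted as vectors of 4 (as_vec/pack_vec), the launch covers elements / 4 indices, and an inner for(index_int j = 0; j < vec_size; j++) applies f lane by lane. A plain-C++ analogue of that inner structure (no HIP, no real vector types; nary4 and its arguments are illustrative only):

#include <cstdint>
#include <vector>

using index_int = std::uint32_t;

// Illustrative only: apply a binary op over two inputs, four lanes per
// "vector" index, mirroring the shape of the vectorized nary inner loop.
template <class F>
void nary4(std::vector<float>& out,
           const std::vector<float>& a,
           const std::vector<float>& b,
           F f)
{
    const index_int vec_size  = 4;
    const index_int nelements = out.size() / vec_size;
    for(index_int i = 0; i < nelements; i++) // one iteration per pack of 4
    {
        for(index_int j = 0; j < vec_size; j++) // the per-lane loop from the kernel
        {
            index_int k = i * vec_size + j;
            out[k]      = f(a[k], b[k]);
        }
    }
}
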
@@ -75,8 +75,8 @@ struct highest
 };
 #ifdef MIGRAPHX_NO_DPP
-template <std::size_t N, class Op, class T, class F>
-__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
+template <index_int N, class Op, class T, class F>
+__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
 {
     using type = decltype(f(idx.local));
     MIGRAPHX_DEVICE_SHARED type buffer[N];
@@ -85,9 +85,9 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
     buffer[idx.local] = x;
     __syncthreads();
-    for(std::size_t s = 1; s < idx.nlocal(); s *= 2)
+    for(index_int s = 1; s < idx.nlocal(); s *= 2)
     {
-        const std::size_t index = 2 * s * idx.local;
+        const index_int index = 2 * s * idx.local;
         if(index + s < idx.nlocal())
         {
             buffer[index] = op(buffer[index], buffer[index + s]);
@@ -118,7 +118,7 @@ template <unsigned int DppCtrl,
           class T>
 __device__ T dpp_mov(T& x)
 {
-    static const std::size_t n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
+    static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
     union type
     {
         uint32_t reg[n];
@@ -128,7 +128,7 @@ __device__ T dpp_mov(T& x)
     type input{};
     // cppcheck-suppress unreadVariable
     input.data = x;
-    for(std::size_t i = 0; i < n; i++)
+    for(index_int i = 0; i < n; i++)
    {
         output.reg[i] = __llvm_amdgcn_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
     }
@@ -176,8 +176,8 @@ __device__ inline void dpp_reduce(float& x, sum)
 #endif
 }
-template <std::size_t N, class Op, class T, class F>
-__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
+template <index_int N, class Op, class T, class F>
+__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
 {
     using type = decltype(f(idx.local));
     MIGRAPHX_DEVICE_SHARED type buffer[N / 64];
@@ -193,14 +193,14 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
     __syncthreads();
     type y = init;
-    for(std::size_t i = 0; i < idx.nlocal() / 64; i++)
+    for(index_int i = 0; i < idx.nlocal() / 64; i++)
     {
         y = op(y, buffer[i]);
     }
     return y;
 }
 #endif
-constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
+constexpr index_int compute_block_size(index_int n, index_int max_block_size)
 {
     size_t block_size = 64;
     while(block_size < max_block_size and block_size < n)
@@ -222,8 +222,8 @@ void reduce_multi_impl(hipStream_t stream,
     auto nelements = result.get_shape().elements();
     auto relements = reduce_slice.elements();
-    const std::size_t max_block_size = 256;
-    const std::size_t block_size     = compute_block_size(relements, max_block_size);
+    const index_int max_block_size = 256;
+    const index_int block_size     = compute_block_size(relements, max_block_size);
     gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
         const auto out_idx = i / block_size;
         auto base_idx      = output.get_shape().multi(out_idx);
@@ -245,13 +245,13 @@ void reduce_standard_impl(hipStream_t stream,
                           T init,
                           Input read_input,
                           Output read_output,
-                          std::size_t relements)
+                          index_int relements)
 {
     hip_visit_all(result, arg)([&](auto output, auto input) {
         auto nelements = result.get_shape().elements();
-        const std::size_t max_block_size = 256;
-        const std::size_t block_size     = compute_block_size(relements, max_block_size);
+        const index_int max_block_size = 256;
+        const index_int block_size     = compute_block_size(relements, max_block_size);
         gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
             const auto out_idx  = i / block_size;
             const auto base_idx = out_idx * relements;
@@ -287,12 +287,12 @@ void reduce(hipStream_t stream,
     }
     else
     {
-        std::vector<std::size_t> reduce_lens;
+        std::vector<index_int> reduce_lens;
         std::transform(output_shape.lens().begin(),
                        output_shape.lens().end(),
                        input_shape.lens().begin(),
                        std::back_inserter(reduce_lens),
-                       [](auto x, auto y) -> std::size_t {
+                       [](auto x, auto y) -> index_int {
                            if(x == y)
                                return 1;
                            else
...
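
The MIGRAPHX_NO_DPP block_reduce is a classic shared-memory tree reduction: each round s = 1, 2, 4, ... combines buffer[2*s*t] with buffer[2*s*t + s], halving the number of live partial results until buffer[0] holds the answer. A serial host-side rehearsal of that schedule, purely illustrative (each t stands in for a thread, so no __syncthreads is needed):

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

using index_int = std::uint32_t;

// Serial rehearsal of the tree schedule used by the MIGRAPHX_NO_DPP block_reduce:
// in round s, "thread" t combines buffer[2*s*t] and buffer[2*s*t + s].
template <class Op>
float tree_reduce(std::vector<float> buffer, Op op)
{
    const index_int nlocal = buffer.size();
    for(index_int s = 1; s < nlocal; s *= 2)
    {
        for(index_int t = 0; t < nlocal; t++) // each t plays the role of idx.local
        {
            const index_int index = 2 * s * t;
            if(index + s < nlocal)
                buffer[index] = op(buffer[index], buffer[index + s]);
        }
    }
    return buffer[0];
}

int main()
{
    std::vector<float> xs(256);
    std::iota(xs.begin(), xs.end(), 1.0f); // 1 + 2 + ... + 256
    std::cout << tree_reduce(xs, [](float a, float b) { return a + b; }) << "\n"; // 32896
}
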
@@ -10,14 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-template <std::size_t N>
+template <index_int N>
 struct hip_shape
 {
-    using hip_index = hip_array<std::size_t, N>;
+    using hip_index = hip_array<index_int, N>;
-    hip_array<std::size_t, N> lens    = {};
-    hip_array<std::size_t, N> strides = {};
-    hip_array<std::size_t, N> divs    = {};
+    hip_index lens                   = {};
+    hip_index strides                = {};
+    hip_array<std::uint64_t, N> divs = {};
     bool standard                    = false;
     __device__ __host__ hip_shape() = default;
@@ -31,34 +31,34 @@ struct hip_shape
         std::transform(s.lens().begin(), s.lens().end(), divs.begin(), &encode_divisor);
     }
-    MIGRAPHX_DEVICE_CONSTEXPR std::size_t elements() const { return lens.product(); }
-    MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const { return x.dot(strides); }
-    MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::initializer_list<std::size_t> x) const
+    MIGRAPHX_DEVICE_CONSTEXPR index_int elements() const { return lens.product(); }
+    MIGRAPHX_DEVICE_CONSTEXPR index_int index(hip_index x) const { return x.dot(strides); }
+    MIGRAPHX_DEVICE_CONSTEXPR index_int index(std::initializer_list<index_int> x) const
     {
-        std::size_t idx = 0;
-        for(std::size_t i = 0; i < x.size(); i++)
+        index_int idx = 0;
+        for(index_int i = 0; i < x.size(); i++)
             idx += *(x.begin() + i) * strides[i];
         return idx;
     }
-    MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::size_t i) const
+    MIGRAPHX_DEVICE_CONSTEXPR index_int index(index_int i) const
     {
         if(this->standard)
             return i;
         else
         {
-            const std::size_t rank = this->lens.size();
-            std::size_t s          = 1;
-            std::size_t result     = 0;
-            for(std::size_t j = 0; j < this->lens.size(); j++)
+            const index_int rank = this->lens.size();
+            index_int s          = 1;
+            index_int result     = 0;
+            for(index_int j = 0; j < this->lens.size(); j++)
             {
-                const std::size_t k      = rank - j - 1;
-                const std::size_t stride = this->strides[k];
-                const std::size_t len    = this->lens[k];
-                const std::size_t slen   = s * len;
-                const std::size_t idx    = (i % slen) / s;
+                const index_int k      = rank - j - 1;
+                const index_int stride = this->strides[k];
+                const index_int len    = this->lens[k];
+                const index_int slen   = s * len;
+                const index_int idx    = (i % slen) / s;
                 result += stride * idx;
                 s = slen;
             }
@@ -66,10 +66,10 @@ struct hip_shape
         }
     }
-    MIGRAPHX_DEVICE_CONSTEXPR hip_index multi(std::size_t idx) const
+    MIGRAPHX_DEVICE_CONSTEXPR hip_index multi(index_int idx) const
     {
         hip_index result;
-        std::size_t tidx = idx;
+        index_int tidx = idx;
         for(std::ptrdiff_t is = result.size() - 1; is > 0; is--)
         {
             // result[is] = tidx % lens[is];
@@ -83,7 +83,7 @@ struct hip_shape
     }
 };
-template <std::size_t N>
+template <index_int N>
 hip_shape<N> make_hip_shape(const shape& x)
 {
     return x;
...
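
hip_shape::index(i) converts a flat element index into a memory offset by peeling off one dimension per iteration, innermost first: idx = (i % (len * s)) / s selects the coordinate along dimension k, and stride * idx adds its contribution. A host-side copy of that loop, checked on a made-up transposed 2x3 layout (linear_offset and the lens/strides values are illustrative, not library code):

#include <cstdint>
#include <iostream>
#include <vector>

using index_int = std::uint32_t;

// Host-side copy of the non-standard-path hip_shape::index(index_int) loop.
index_int linear_offset(const std::vector<index_int>& lens,
                        const std::vector<index_int>& strides,
                        index_int i)
{
    const index_int rank = lens.size();
    index_int s          = 1;
    index_int result     = 0;
    for(index_int j = 0; j < rank; j++)
    {
        const index_int k      = rank - j - 1;
        const index_int stride = strides[k];
        const index_int len    = lens[k];
        const index_int slen   = s * len;
        const index_int idx    = (i % slen) / s;
        result += stride * idx;
        s = slen;
    }
    return result;
}

int main()
{
    // A 2x3 view over a transposed buffer: lens = {2, 3}, strides = {1, 2}.
    std::vector<index_int> lens{2, 3}, strides{1, 2};
    for(index_int i = 0; i < 6; i++)
        std::cout << linear_offset(lens, strides, i) << " "; // prints 0 2 4 1 3 5
    std::cout << "\n";
}
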
@@ -8,10 +8,10 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-template <std::size_t NDim>
-using hip_tensor_index = hip_array<std::size_t, NDim>;
+template <index_int NDim>
+using hip_tensor_index = hip_array<index_int, NDim>;
-template <std::size_t NDim>
+template <index_int NDim>
 struct hip_tensor_descriptor
 {
     __device__ __host__ hip_tensor_descriptor() = default;
@@ -22,11 +22,11 @@ struct hip_tensor_descriptor
         std::copy(s.strides().begin(), s.strides().end(), strides);
     }
-    __device__ __host__ hip_tensor_index<NDim> multi(std::size_t idx) const
+    __device__ __host__ hip_tensor_index<NDim> multi(index_int idx) const
     {
         hip_tensor_index<NDim> result{};
-        std::size_t tidx = idx;
-        for(std::size_t is = 0; is < NDim; is++)
+        index_int tidx = idx;
+        for(index_int is = 0; is < NDim; is++)
         {
             result[is] = tidx / strides[is];
             tidx       = tidx % strides[is];
@@ -34,15 +34,15 @@ struct hip_tensor_descriptor
         return result;
     }
-    __device__ __host__ std::size_t linear(hip_tensor_index<NDim> s) const
+    __device__ __host__ index_int linear(hip_tensor_index<NDim> s) const
     {
-        std::size_t idx = 0;
-        for(std::size_t i = 0; i < NDim; i++)
+        index_int idx = 0;
+        for(index_int i = 0; i < NDim; i++)
             idx += s[i] * strides[i];
         return idx;
     }
-    std::size_t lens[NDim]    = {};
-    std::size_t strides[NDim] = {};
+    index_int lens[NDim]    = {};
+    index_int strides[NDim] = {};
 };
 } // namespace device
...
@@ -9,7 +9,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-template <class T, std::size_t N>
+template <class T, index_int N>
 struct hip_tensor_view
 {
     using value_type = T;
@@ -20,7 +20,7 @@ struct hip_tensor_view
     MIGRAPHX_DEVICE_CONSTEXPR const hip_shape<N>& get_shape() const { return s; }
-    MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return s.elements(); }
+    MIGRAPHX_DEVICE_CONSTEXPR index_int size() const { return s.elements(); }
     MIGRAPHX_DEVICE_CONSTEXPR value_type* data() const { return d; }
@@ -39,13 +39,13 @@ struct hip_tensor_view
     hip_shape<N> s{};
 };
-template <std::size_t N, class T>
+template <index_int N, class T>
 hip_tensor_view<T, N> make_hip_view(const shape& s, T* x)
 {
     return {x, s};
 }
-template <std::size_t N, class T>
+template <index_int N, class T>
 hip_tensor_view<T, N> make_hip_view(tensor_view<T> x)
 {
     return {x};
...
@@ -18,33 +18,35 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
+using index_int = std::uint32_t;
 #define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
-template <class T, std::size_t N>
+template <class T, index_int N>
 using vec = T __attribute__((ext_vector_type(N)));
-template <std::size_t N, class T>
+template <index_int N, class T>
 __device__ __host__ T* as_pointer(vec<T, N>* x)
 {
     return reinterpret_cast<T*>(x);
 }
-template <std::size_t N, class T>
+template <index_int N, class T>
 __device__ __host__ vec<T, N>* as_vec(T* x)
 {
     return reinterpret_cast<vec<T, N>*>(x);
 }
-template <std::size_t N, class T>
+template <index_int N, class T>
 tensor_view<vec<T, N>> as_vec(tensor_view<T> x)
 {
     return {x.get_shape(), as_vec<N>(x.data())};
 }
-template <std::size_t N, class... Ts>
+template <index_int N, class... Ts>
 auto pack_vec(Ts... xs)
 {
-    return [=](auto f, std::size_t n) { return f(as_vec<N>(xs)[n]...); };
+    return [=](auto f, index_int n) { return f(as_vec<N>(xs)[n]...); };
 }
 using gpu_half = __fp16;
@@ -56,7 +58,7 @@ struct device_type
     using type = T;
 };
-template <class T, std::size_t N>
+template <class T, index_int N>
 struct device_type<vec<T, N>>
 {
     using type = vec<typename device_type<T>::type, N>;
...
@@ -10,11 +10,11 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-template <class T, std::size_t N>
+template <class T, index_int N>
 struct hip_vector
 {
     MIGRAPHX_DEVICE_CONSTEXPR hip_vector() = default;
-    MIGRAPHX_DEVICE_CONSTEXPR hip_vector(std::size_t s) : len(s) {}
+    MIGRAPHX_DEVICE_CONSTEXPR hip_vector(index_int s) : len(s) {}
     template <class Iterator>
     __device__ __host__ hip_vector(Iterator start, Iterator last)
     {
@@ -28,8 +28,8 @@ struct hip_vector
         len = x.size();
     }
-    MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; }
-    MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; }
+    MIGRAPHX_DEVICE_CONSTEXPR T& operator[](index_int i) { return d[i]; }
+    MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](index_int i) const { return d[i]; }
     MIGRAPHX_DEVICE_CONSTEXPR T& front() { return d[0]; }
     MIGRAPHX_DEVICE_CONSTEXPR const T& front() const { return d[0]; }
@@ -40,7 +40,7 @@ struct hip_vector
     MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; }
     MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; }
-    MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return len; }
+    MIGRAPHX_DEVICE_CONSTEXPR index_int size() const { return len; }
     MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
     MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }
@@ -56,11 +56,11 @@ struct hip_vector
     }
     private:
     T d[N] = {};
-    std::size_t len = 0;
+    index_int len = 0;
 };
-template <std::size_t N, class T>
+template <index_int N, class T>
 hip_vector<T, N> to_hip_vector(const std::vector<T>& x)
 {
     hip_vector<T, N> result(x.size());
...
@@ -10,33 +10,33 @@ namespace gpu {
 namespace device {
 template <class F>
-void visit_tensor_size(std::size_t n, F f)
+void visit_tensor_size(index_int n, F f)
 {
     switch(n)
     {
     case 1:
     {
-        f(std::integral_constant<std::size_t, 1>{});
+        f(std::integral_constant<index_int, 1>{});
         break;
     }
     case 2:
     {
-        f(std::integral_constant<std::size_t, 2>{});
+        f(std::integral_constant<index_int, 2>{});
         break;
     }
     case 3:
     {
-        f(std::integral_constant<std::size_t, 3>{});
+        f(std::integral_constant<index_int, 3>{});
         break;
     }
     case 4:
     {
-        f(std::integral_constant<std::size_t, 4>{});
+        f(std::integral_constant<index_int, 4>{});
         break;
     }
     case 5:
     {
-        f(std::integral_constant<std::size_t, 5>{});
+        f(std::integral_constant<index_int, 5>{});
         break;
     }
     default: throw std::runtime_error("Unknown tensor size");
@@ -58,9 +58,9 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
     if(!std::all_of(
            types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
         MIGRAPHX_THROW("Types must be the same");
-    std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...};
-    if(!std::all_of(
-           ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
+    std::initializer_list<index_int> ranks = {
+        static_cast<index_int>(get_shape(xs).lens().size())...};
+    if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
         MIGRAPHX_THROW("Ranks must be the same");
     visit_tensor_size(s.lens().size(),
                       [&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); });
@@ -69,9 +69,9 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
 template <class V, class F, class... Ts>
 void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
 {
-    std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...};
-    if(!std::all_of(
-           ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
+    std::initializer_list<index_int> ranks = {
+        static_cast<index_int>(get_shape(xs).lens().size())...};
+    if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
         MIGRAPHX_THROW("Ranks must be the same");
     visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
 }
@@ -132,7 +132,7 @@ auto hip_visit_all(T&& x, Ts&&... xs)
     };
 }
-template <std::size_t N, class T, class... Ts>
+template <index_int N, class T, class... Ts>
 auto hip_vec_visit_all(T&& x, Ts&&... xs)
 {
     return [&](auto f) {
...
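
visit_tensor_size turns a runtime rank into a compile-time constant: the switch invokes the callback with std::integral_constant<index_int, K>, so hip_shape<K> and hip_tensor_view<T, K> can be instantiated for each supported rank while K stays a template parameter. A minimal, self-contained version of that dispatch pattern (illustrative only, mirroring the header):

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <type_traits>

using index_int = std::uint32_t;

// Minimal version of the rank dispatch: call f with a compile-time rank
// chosen from a runtime value (only ranks 1-5 are supported, as in the header).
template <class F>
void visit_tensor_size(index_int n, F f)
{
    switch(n)
    {
    case 1: f(std::integral_constant<index_int, 1>{}); break;
    case 2: f(std::integral_constant<index_int, 2>{}); break;
    case 3: f(std::integral_constant<index_int, 3>{}); break;
    case 4: f(std::integral_constant<index_int, 4>{}); break;
    case 5: f(std::integral_constant<index_int, 5>{}); break;
    default: throw std::runtime_error("Unknown tensor size");
    }
}

int main()
{
    index_int rank = 3; // pretend this came from shape.lens().size()
    visit_tensor_size(rank, [](auto ndim) {
        // decltype(ndim)::value is a constant expression, usable as a template argument
        constexpr index_int k = decltype(ndim)::value;
        std::cout << "compile-time rank: " << k << "\n";
    });
}
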
@@ -14,15 +14,15 @@ namespace device {
 void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
 {
     auto lens       = result.get_shape().lens();
     auto batch_lens = lens;
-    std::size_t batch_item_num = lens[axis];
+    index_int batch_item_num = lens[axis];
     batch_lens[axis] = 1;
     migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
     hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
-        const std::size_t max_block_size = 256;
-        const std::size_t block_size     = compute_block_size(batch_item_num, max_block_size);
+        const index_int max_block_size = 256;
+        const index_int block_size     = compute_block_size(batch_item_num, max_block_size);
         gs_launch(stream,
                   batch_shape.elements() * block_size,
                   block_size)([=](auto i, auto idx) __device__ {
...
@@ -8,7 +8,7 @@ namespace device {
 void reduce_mean(hipStream_t stream, const argument& result, const argument& arg)
 {
-    std::size_t item_num = arg.get_shape().elements() / result.get_shape().elements();
+    index_int item_num = arg.get_shape().elements() / result.get_shape().elements();
     reduce(stream, result, arg, sum{}, 0, id{}, mean{item_num});
 }
...
@@ -15,15 +15,15 @@ namespace device {
 void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
 {
     auto lens       = result.get_shape().lens();
     auto batch_lens = lens;
-    std::size_t batch_item_num = lens[axis];
+    index_int batch_item_num = lens[axis];
     batch_lens[axis] = 1;
     migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
     hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
-        const std::size_t max_block_size = 256;
-        const std::size_t block_size     = compute_block_size(batch_item_num, max_block_size);
+        const index_int max_block_size = 256;
+        const index_int block_size     = compute_block_size(batch_item_num, max_block_size);
         gs_launch(stream,
                   batch_shape.elements() * block_size,
                   block_size)([=](auto i, auto idx) __device__ {
...
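
softmax (like logsoftmax above) reduces along one axis: batch_item_num = lens[axis] is the length of each softmax slice, batch_lens sets that axis to 1 so batch_shape.elements() counts the number of slices, and the launch requests block_size work items per slice so each block cooperatively reduces one slice. Worked numbers under an assumed output shape of {32, 1000} with axis = 1 (values chosen only for illustration):

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

using index_int = std::uint32_t;

int main()
{
    // Hypothetical output shape {32, 1000}, softmax over axis 1.
    std::vector<index_int> lens{32, 1000};
    int axis                 = 1;
    index_int batch_item_num = lens[axis]; // 1000 values reduced per slice
    auto batch_lens          = lens;
    batch_lens[axis]         = 1;
    index_int slices         = std::accumulate(batch_lens.begin(),
                                               batch_lens.end(),
                                               index_int{1},
                                               [](index_int a, index_int b) { return a * b; });
    std::cout << "slices=" << slices << " values/slice=" << batch_item_num << "\n"; // 32, 1000
}
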