Commit 3a848f0d authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into doc2

parents 64e8e30a d1e945da
...@@ -78,8 +78,9 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024) ...@@ -78,8 +78,9 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
index_int nglobal = std::min<index_int>(256, groups) * local; index_int nglobal = std::min<index_int>(256, groups) * local;
return [=](auto f) { return [=](auto f) {
launch(stream, nglobal, local)( launch(stream, nglobal, local)([=](auto idx) __device__ {
[=](auto idx) { idx.global_stride(n, [&](auto i) { gs_invoke(f, i, idx); }); }); idx.global_stride(n, [&](auto i) { gs_invoke(f, i, idx); });
});
}; };
} }
......
...@@ -95,7 +95,7 @@ inline auto mi_launch(hipStream_t stream, const hip_shape<N>& global, index_int ...@@ -95,7 +95,7 @@ inline auto mi_launch(hipStream_t stream, const hip_shape<N>& global, index_int
auto nglobal = global.index(nglobal_multi); auto nglobal = global.index(nglobal_multi);
return [=](auto f) { return [=](auto f) {
launch(stream, nglobal, nlocal)([=](auto idx) { launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
auto midx = make_multi_index(global, idx.global, nglobal_multi); auto midx = make_multi_index(global, idx.global, nglobal_multi);
f(idx, midx.for_stride(global.lens)); f(idx, midx.for_stride(global.lens));
}); });
......
...@@ -36,7 +36,8 @@ auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, A ...@@ -36,7 +36,8 @@ auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, A
MIGRAPHX_TRACE_NARY_FUNCTION MIGRAPHX_TRACE_NARY_FUNCTION
shape s{result.get_shape().type(), result.get_shape().lens()}; shape s{result.get_shape().type(), result.get_shape().lens()};
hip_visit_all(s, result, args...)([&](auto standard_shape, auto output, auto... inputs) { hip_visit_all(s, result, args...)([&](auto standard_shape, auto output, auto... inputs) {
mi_gs_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); }); mi_gs_launch(stream,
standard_shape)([=](auto idx) __device__ { output[idx] = f(inputs[idx]...); });
}); });
} }
...@@ -45,7 +46,7 @@ inline auto create_broadcast_index(index_int len, index_int stride) ...@@ -45,7 +46,7 @@ inline auto create_broadcast_index(index_int len, index_int stride)
auto next_stride = stride * len; auto next_stride = stride * len;
auto e_next_stride = encode_divisor(next_stride); auto e_next_stride = encode_divisor(next_stride);
auto e_stride = encode_divisor(stride); auto e_stride = encode_divisor(stride);
return [=](auto i) { return [=](auto i) __device__ {
// ( i % next_stride) / stride // ( i % next_stride) / stride
return fast_div(i, e_stride) - len * fast_div(i, e_next_stride); return fast_div(i, e_stride) - len * fast_div(i, e_next_stride);
}; };
...@@ -61,10 +62,10 @@ auto nary_nonstandard_packed_impl(hipStream_t stream, ...@@ -61,10 +62,10 @@ auto nary_nonstandard_packed_impl(hipStream_t stream,
auto arg_shape = make_array(args...).front().get_shape(); auto arg_shape = make_array(args...).front().get_shape();
auto perm = find_permutation(arg_shape); auto perm = find_permutation(arg_shape);
auto s = reorder_shape(arg_shape, perm); auto s = reorder_shape(arg_shape, perm);
hip_visit_all(s, hip_visit_all(s, result.reshape(reorder_shape(result.get_shape(), perm)), args.reshape(s)...)(
result.reshape(reorder_shape(result.get_shape(), perm)), [&](auto standard_shape, auto output, auto... inputs) {
args.reshape(s)...)([&](auto standard_shape, auto output, auto... inputs) { mi_gs_launch(stream, standard_shape)(
mi_gs_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); }); [=](auto idx) __device__ { output[idx] = f(inputs[idx]...); });
}); });
} }
...@@ -93,7 +94,6 @@ void nary_broadcast_vec_impl( ...@@ -93,7 +94,6 @@ void nary_broadcast_vec_impl(
using type = typename decltype(output)::value_type; using type = typename decltype(output)::value_type;
const index_int nelements = output.size() / vec_size; const index_int nelements = output.size() / vec_size;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ { launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size]; MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
// Load bias into LDS // Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal) for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
...@@ -185,7 +185,6 @@ void nary_double_broadcast_vec_impl( ...@@ -185,7 +185,6 @@ void nary_double_broadcast_vec_impl(
using type = typename decltype(output)::value_type; using type = typename decltype(output)::value_type;
const index_int nelements = output.size() / vec_size; const index_int nelements = output.size() / vec_size;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ { launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size]; MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
// Load bias into LDS // Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal) for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
...@@ -274,7 +273,7 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments. ...@@ -274,7 +273,7 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
const index_int vec_size = 4; const index_int vec_size = 4;
auto data = pack_vec<4>(device_cast(inputs.data())...); auto data = pack_vec<4>(device_cast(inputs.data())...);
auto* outp = as_vec<4>(device_cast(output.data())); auto* outp = as_vec<4>(device_cast(output.data()));
gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) { gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) __device__ {
vec<type, 4> out = outp[i]; vec<type, 4> out = outp[i];
data( data(
[&](auto... xs) { [&](auto... xs) {
...@@ -295,7 +294,7 @@ void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... a ...@@ -295,7 +294,7 @@ void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... a
MIGRAPHX_TRACE_NARY_FUNCTION MIGRAPHX_TRACE_NARY_FUNCTION
index_int nelements = result.get_shape().elements(); index_int nelements = result.get_shape().elements();
hip_pointer_visit_all(result, args...)([&](auto output, auto... inputs) { hip_pointer_visit_all(result, args...)([&](auto output, auto... inputs) {
gs_launch(stream, nelements)([=](auto i) { output[i] = f(inputs[i]...); }); gs_launch(stream, nelements)([=](auto i) __device__ { output[i] = f(inputs[i]...); });
}); });
} }
......
...@@ -20,6 +20,15 @@ struct sum ...@@ -20,6 +20,15 @@ struct sum
} }
}; };
struct product
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x * y;
}
};
struct id struct id
{ {
template <class T> template <class T>
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment