Commit 3a848f0d authored by Paul
Browse files

Merge branch 'develop' into doc2

parents 64e8e30a d1e945da
......@@ -78,8 +78,9 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
index_int nglobal = std::min<index_int>(256, groups) * local;
return [=](auto f) {
launch(stream, nglobal, local)(
[=](auto idx) { idx.global_stride(n, [&](auto i) { gs_invoke(f, i, idx); }); });
launch(stream, nglobal, local)([=](auto idx) __device__ {
idx.global_stride(n, [&](auto i) { gs_invoke(f, i, idx); });
});
};
}
......
......@@ -95,7 +95,7 @@ inline auto mi_launch(hipStream_t stream, const hip_shape<N>& global, index_int
auto nglobal = global.index(nglobal_multi);
return [=](auto f) {
launch(stream, nglobal, nlocal)([=](auto idx) {
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
auto midx = make_multi_index(global, idx.global, nglobal_multi);
f(idx, midx.for_stride(global.lens));
});
......
......@@ -36,7 +36,8 @@ auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, A
MIGRAPHX_TRACE_NARY_FUNCTION
shape s{result.get_shape().type(), result.get_shape().lens()};
hip_visit_all(s, result, args...)([&](auto standard_shape, auto output, auto... inputs) {
mi_gs_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); });
mi_gs_launch(stream,
standard_shape)([=](auto idx) __device__ { output[idx] = f(inputs[idx]...); });
});
}
......@@ -45,7 +46,7 @@ inline auto create_broadcast_index(index_int len, index_int stride)
auto next_stride = stride * len;
auto e_next_stride = encode_divisor(next_stride);
auto e_stride = encode_divisor(stride);
return [=](auto i) {
return [=](auto i) __device__ {
// ( i % next_stride) / stride
return fast_div(i, e_stride) - len * fast_div(i, e_next_stride);
};
......@@ -61,10 +62,10 @@ auto nary_nonstandard_packed_impl(hipStream_t stream,
auto arg_shape = make_array(args...).front().get_shape();
auto perm = find_permutation(arg_shape);
auto s = reorder_shape(arg_shape, perm);
hip_visit_all(s,
result.reshape(reorder_shape(result.get_shape(), perm)),
args.reshape(s)...)([&](auto standard_shape, auto output, auto... inputs) {
mi_gs_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); });
hip_visit_all(s, result.reshape(reorder_shape(result.get_shape(), perm)), args.reshape(s)...)(
[&](auto standard_shape, auto output, auto... inputs) {
mi_gs_launch(stream, standard_shape)(
[=](auto idx) __device__ { output[idx] = f(inputs[idx]...); });
});
}
......@@ -93,7 +94,6 @@ void nary_broadcast_vec_impl(
using type = typename decltype(output)::value_type;
const index_int nelements = output.size() / vec_size;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
......@@ -185,7 +185,6 @@ void nary_double_broadcast_vec_impl(
using type = typename decltype(output)::value_type;
const index_int nelements = output.size() / vec_size;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
......@@ -274,7 +273,7 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
const index_int vec_size = 4;
auto data = pack_vec<4>(device_cast(inputs.data())...);
auto* outp = as_vec<4>(device_cast(output.data()));
gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) {
gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) __device__ {
vec<type, 4> out = outp[i];
data(
[&](auto... xs) {
......@@ -295,7 +294,7 @@ void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... a
MIGRAPHX_TRACE_NARY_FUNCTION
index_int nelements = result.get_shape().elements();
hip_pointer_visit_all(result, args...)([&](auto output, auto... inputs) {
gs_launch(stream, nelements)([=](auto i) { output[i] = f(inputs[i]...); });
gs_launch(stream, nelements)([=](auto i) __device__ { output[i] = f(inputs[i]...); });
});
}
......
......@@ -20,6 +20,15 @@ struct sum
}
};
// Binary function object that multiplies its two arguments.
struct product
{
    // Returns lhs * rhs. The operands may have different types; the result
    // type is whatever operator* produces for them (deduced via auto).
    template <typename Lhs, typename Rhs>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(Lhs lhs, Rhs rhs) const
    {
        return lhs * rhs;
    }
};
struct id
{
template <class T>
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment