"vscode:/vscode.git/clone" did not exist on "34de04a61ccb9959fdf1d09136e66f5db285f9bb"
Commit 9f755219 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

revert nary

parent 08ac24cf
...@@ -87,7 +87,7 @@ void nary_broadcast_vec_impl( ...@@ -87,7 +87,7 @@ void nary_broadcast_vec_impl(
const index_int vec_size = 4; const index_int vec_size = 4;
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 512 * nlocal; const index_int nglobal = 256 * nlocal;
const index_int bdim_vec_len = bdim_len / vec_size; const index_int bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg, args...)( hip_vec_visit_all<vec_size>(result, barg, args...)(
[&](auto output, auto binput, auto... inputs) { [&](auto output, auto binput, auto... inputs) {
...@@ -134,7 +134,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg ...@@ -134,7 +134,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride); auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 512 * nlocal; const index_int nglobal = 256 * nlocal;
index_int nelements = result.get_shape().elements(); index_int nelements = result.get_shape().elements();
hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) { hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
using type = typename decltype(output)::value_type; using type = typename decltype(output)::value_type;
...@@ -178,7 +178,7 @@ void nary_double_broadcast_vec_impl( ...@@ -178,7 +178,7 @@ void nary_double_broadcast_vec_impl(
const index_int vec_size = 4; const index_int vec_size = 4;
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 512 * nlocal; const index_int nglobal = 256 * nlocal;
const index_int bdim_vec_len = bdim_len / vec_size; const index_int bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)( hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) { [&](auto output, auto binput1, auto binput2, auto... inputs) {
...@@ -234,7 +234,7 @@ void nary_double_broadcast_impl( ...@@ -234,7 +234,7 @@ void nary_double_broadcast_impl(
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride); auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 512 * nlocal; const index_int nglobal = 256 * nlocal;
index_int nelements = result.get_shape().elements(); index_int nelements = result.get_shape().elements();
hip_visit_all(result, barg1, barg2, args...)( hip_visit_all(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) { [&](auto output, auto binput1, auto binput2, auto... inputs) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment