Commit 9f755219 authored by Khalique Ahmed
Browse files

revert nary

parent 08ac24cf
......@@ -87,7 +87,7 @@ void nary_broadcast_vec_impl(
const index_int vec_size = 4;
const index_int nlocal = 1024;
-const index_int nglobal = 512 * nlocal;
+const index_int nglobal = 256 * nlocal;
const index_int bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg, args...)(
[&](auto output, auto binput, auto... inputs) {
......@@ -134,7 +134,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const index_int nlocal = 1024;
-const index_int nglobal = 512 * nlocal;
+const index_int nglobal = 256 * nlocal;
index_int nelements = result.get_shape().elements();
hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
using type = typename decltype(output)::value_type;
......@@ -178,7 +178,7 @@ void nary_double_broadcast_vec_impl(
const index_int vec_size = 4;
const index_int nlocal = 1024;
-const index_int nglobal = 512 * nlocal;
+const index_int nglobal = 256 * nlocal;
const index_int bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) {
......@@ -234,7 +234,7 @@ void nary_double_broadcast_impl(
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const index_int nlocal = 1024;
-const index_int nglobal = 512 * nlocal;
+const index_int nglobal = 256 * nlocal;
index_int nelements = result.get_shape().elements();
hip_visit_all(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment