Commit a7b934bf authored by Paul
Browse files

Fix bug with vector loads

parent 972cfffe
...@@ -19,6 +19,12 @@ vec4<T>* as_vec4(T* x) ...@@ -19,6 +19,12 @@ vec4<T>* as_vec4(T* x)
return reinterpret_cast<vec4<T>*>(x); return reinterpret_cast<vec4<T>*>(x);
} }
// Bundle several raw pointers so they can be read one vec4 at a time.
// The returned callable takes a function `fn` and a vector index `slot`;
// each captured pointer is reinterpreted through as_vec4, the slot-th vec4
// of every pointer is selected, and those vec4s are passed to `fn` in the
// same order the pointers were given. The result of `fn` is returned.
template <class... Ts>
auto pack_vec4(Ts... xs)
{
    return [=](auto fn, std::size_t slot) {
        return fn(*(as_vec4(xs) + slot)...);
    };
}
template <class F, class... Arguments> template <class F, class... Arguments>
auto nary_nonstandard_impl(F f, argument result, Arguments... args) auto nary_nonstandard_impl(F f, argument result, Arguments... args)
{ {
...@@ -59,31 +65,33 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2) ...@@ -59,31 +65,33 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) { visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>; using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = as_vec4(input1.data()); auto* xp = as_vec4(input1.data());
auto* yp = input2.data(); auto* yp = as_vec4(input2.data());
auto* outp = as_vec4(output.data()); auto* outp = as_vec4(output.data());
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal; const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size() / 4; const std::size_t n = output.size() / vec_size;
const std::size_t bdim_vec_len = bdim_len / vec_size;
launch(nglobal, nlocal)([=](auto idx) __device__ { launch(nglobal, nlocal)([=](auto idx) __device__ {
__shared__ type buffer[2048]; __shared__ vec4<type> buffer[2048];
for(size_t i = idx.local; i < bdim_len; i += nlocal) for(size_t i = idx.local; i < bdim_len / vec_size; i += nlocal)
{ {
buffer[i] = yp[i]; buffer[i] = yp[i];
} }
__syncthreads(); __syncthreads();
for(size_t i = idx.global; i < n; i += nglobal) for(size_t i = idx.global; i < n; i += nglobal)
{ {
auto bidx = i % bdim_vec_len;
auto b = buffer[bidx];
vec4<type> x = xp[i]; vec4<type> x = xp[i];
vec4<type> out = outp[i]; vec4<type> out = outp[i];
for(std::size_t j = 0; j < 4; j++) for(std::size_t j = 0; j < vec_size; j++)
{ {
auto gidx = i * 4 + j; out[j] = f(x[j], b[j]);
auto bidx = gidx % bdim_len;
auto b = buffer[bidx];
out[j] = f(x[j], b);
} }
outp[i] = out;
} }
}); });
}); });
...@@ -97,10 +105,26 @@ auto nary_standard(argument result, Arguments... args) ...@@ -97,10 +105,26 @@ auto nary_standard(argument result, Arguments... args)
// assert(x.get_shape().elements() == y.get_shape().elements()); // assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
#if 1
auto data = pack(inputs.data()...); auto data = pack(inputs.data()...);
auto* outp = output.data(); auto* outp = output.data();
gs_launch(output_shape.elements())( gs_launch(output_shape.elements())(
[=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); }); [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
#else
using type = std::remove_cv_t<typename decltype(output)::value_type>;
const std::size_t vec_size = 4;
auto data = pack_vec4(inputs.data()...);
auto* outp = as_vec4(output.data());
gs_launch(output_shape.elements() / vec_size)([=](auto i) {
vec4<type> out = outp[i];
data([&](auto... xs) {
for(std::size_t j = 0; j < vec_size; j++) {
out[j] = f(xs[j]...);
}
}, i);
outp[i] = out;
});
#endif
}); });
}; };
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment