Commit 7f05ac8a authored by Paul's avatar Paul
Browse files

Load memory into lds

parent 20bdf794
...@@ -39,33 +39,37 @@ auto nary_nonstandard(argument result, Arguments... args) ...@@ -39,33 +39,37 @@ auto nary_nonstandard(argument result, Arguments... args)
inline auto binary_broadcast(argument result, argument arg1, argument arg2) inline auto binary_broadcast(argument result, argument arg1, argument arg2)
{ {
return [=](auto f) { return [=](auto f) {
const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = arg2.get_shape();
auto bdim = std::distance(b_shape.strides().begin(), auto bdim = std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), std::find_if(b_shape.strides().begin(),
b_shape.strides().end(), b_shape.strides().end(),
[](auto x) { return x != 0; })); [](auto x) { return x != 0; }));
auto bdim_len = b_shape.lens()[bdim]; auto bdim_len = b_shape.lens()[bdim];
auto outer_size = std::accumulate(output_shape.lens().begin(),
output_shape.lens().begin() + bdim + 1,
std::size_t{1},
std::multiplies<>{});
auto inner_size = std::accumulate(output_shape.lens().begin() + bdim + 1,
output_shape.lens().end(),
std::size_t{1},
std::multiplies<>{});
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) { visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = input1.data(); auto* xp = input1.data();
auto* yp = input2.data(); auto* yp = input2.data();
auto* outp = output.data(); auto* outp = output.data();
gs_launch(outer_size)([=](auto i) {
auto* outp2 = outp + i; const std::size_t nlocal = 256;
auto* xp2 = xp + i; const std::size_t nglobal = 256 * nlocal;
auto b = yp[i % bdim_len]; const std::size_t n = output.size();
for(std::size_t j = 0; j < inner_size; j++)
launch(nglobal, nlocal)([=](auto idx) __device__ {
__shared__ type buffer[2048];
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i] = yp[i];
}
__syncthreads();
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{ {
outp2[j] = f(xp2[j], b); auto b = buffer[i];
for(size_t j = idx.global; j < n; j += nglobal)
{
outp[j] = f(xp[j], b);
}
} }
}); });
}); });
...@@ -114,6 +118,7 @@ inline auto nary(argument result, argument arg1, argument arg2) ...@@ -114,6 +118,7 @@ inline auto nary(argument result, argument arg1, argument arg2)
return [=](auto f) { return [=](auto f) {
// TODO: Check for one broadcast stride // TODO: Check for one broadcast stride
// TODO: Check result and arg1 shape is the same // TODO: Check result and arg1 shape is the same
// TODO: CHeck that broadcast shape doesnt have more than 2048 elements
if(arg1.get_shape().standard() and arg2.get_shape().broadcasted() and if(arg1.get_shape().standard() and arg2.get_shape().broadcasted() and
std::count_if(arg2.get_shape().strides().begin(), std::count_if(arg2.get_shape().strides().begin(),
arg2.get_shape().strides().end(), arg2.get_shape().strides().end(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment