Commit 1ad95e66 authored by Paul's avatar Paul
Browse files

Formatting

parent bb666690
...@@ -51,69 +51,74 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args) ...@@ -51,69 +51,74 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
}); });
} }
template<class F> template <class F>
void binary_broadcast_vec_impl(F f, const argument& result, const argument& arg1, const argument& arg2) void binary_broadcast_vec_impl(F f,
const argument& result,
const argument& arg1,
const argument& arg2)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = arg2.get_shape();
auto bdim = std::distance(b_shape.strides().begin(), auto bdim =
std::find_if(b_shape.strides().begin(), std::distance(b_shape.strides().begin(),
b_shape.strides().end(), std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
[](auto x) { return x != 0; })); return x != 0;
auto bdim_len = output_shape.lens()[bdim]; }));
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_len = output_shape.lens()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) { visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>; using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = as_vec4(input1.data()); auto* xp = as_vec4(input1.data());
auto* yp = as_vec4(input2.data()); auto* yp = as_vec4(input2.data());
auto* outp = as_vec4(output.data()); auto* outp = as_vec4(output.data());
const std::size_t vec_size = 4; const std::size_t vec_size = 4;
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal; const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size() / vec_size; const std::size_t n = output.size() / vec_size;
const std::size_t bdim_vec_len = bdim_len / vec_size; const std::size_t bdim_vec_len = bdim_len / vec_size;
launch(nglobal, nlocal)([=](auto idx) __device__ { launch(nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPH_DEVICE_SHARED vec4<type> buffer[2048 / vec_size]; MIGRAPH_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
// Load bias into LDS // Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal) for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{ {
buffer[i] = yp[i]; buffer[i] = yp[i];
} }
__syncthreads(); __syncthreads();
auto* bp = as_pointer(buffer); auto* bp = as_pointer(buffer);
// Process the data // Process the data
for(size_t i = idx.global; i < n; i += nglobal) for(size_t i = idx.global; i < n; i += nglobal)
{
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b = bp[bidx];
vec4<type> x = xp[i];
vec4<type> out = outp[i];
for(std::size_t j = 0; j < vec_size; j++)
{ {
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride; out[j] = f(x[j], b);
auto b = bp[bidx];
vec4<type> x = xp[i];
vec4<type> out = outp[i];
for(std::size_t j = 0; j < vec_size; j++)
{
out[j] = f(x[j], b);
}
outp[i] = out;
} }
}); outp[i] = out;
}
}); });
});
} }
template<class F> template <class F>
void binary_broadcast_impl(F f, const argument& result, const argument& arg1, const argument& arg2) void binary_broadcast_impl(F f, const argument& result, const argument& arg1, const argument& arg2)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = arg2.get_shape();
auto bdim = std::distance(b_shape.strides().begin(), auto bdim =
std::find_if(b_shape.strides().begin(), std::distance(b_shape.strides().begin(),
b_shape.strides().end(), std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
[](auto x) { return x != 0; })); return x != 0;
auto bdim_len = output_shape.lens()[bdim]; }));
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_len = output_shape.lens()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) { visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>; using type = std::remove_cv_t<typename decltype(output)::value_type>;
...@@ -148,77 +153,70 @@ void binary_broadcast_impl(F f, const argument& result, const argument& arg1, co ...@@ -148,77 +153,70 @@ void binary_broadcast_impl(F f, const argument& result, const argument& arg1, co
template <class F, class... Arguments> template <class F, class... Arguments>
void nary_standard_vec_impl(F f, argument result, Arguments... args) void nary_standard_vec_impl(F f, argument result, Arguments... args)
{ {
// assert(x.get_shape().elements() == y.get_shape().elements()); // assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
using type = std::remove_cv_t<typename decltype(output)::value_type>; using type = std::remove_cv_t<typename decltype(output)::value_type>;
const std::size_t vec_size = 4; const std::size_t vec_size = 4;
auto data = pack_vec4(inputs.data()...); auto data = pack_vec4(inputs.data()...);
auto* outp = as_vec4(output.data()); auto* outp = as_vec4(output.data());
gs_launch(output_shape.elements() / vec_size)([=](auto i) { gs_launch(output_shape.elements() / vec_size)([=](auto i) {
vec4<type> out = outp[i]; vec4<type> out = outp[i];
data( data(
[&](auto... xs) { [&](auto... xs) {
for(std::size_t j = 0; j < vec_size; j++) for(std::size_t j = 0; j < vec_size; j++)
{ {
out[j] = f(xs[j]...); out[j] = f(xs[j]...);
} }
}, },
i); i);
outp[i] = out; outp[i] = out;
});
}); });
});
} }
template <class F, class... Arguments> template <class F, class... Arguments>
void nary_standard_impl(F f, argument result, Arguments... args) void nary_standard_impl(F f, argument result, Arguments... args)
{ {
// assert(x.get_shape().elements() == y.get_shape().elements()); // assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
auto data = pack(inputs.data()...); auto data = pack(inputs.data()...);
auto* outp = output.data(); auto* outp = output.data();
gs_launch(output_shape.elements())( gs_launch(output_shape.elements())(
[=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); }); [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
}); });
} }
template <class F, class... Arguments> template <class F, class... Arguments>
void nary_impl(F f, argument result, Arguments... args) void nary_impl(F f, argument result, Arguments... args)
{ {
bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); }); bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); }); bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
bool same_shapes = bool same_shapes =
all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); }); all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
if(standard or (packed and same_shapes)) if(standard or (packed and same_shapes))
nary_standard_impl(f, result, args...); nary_standard_impl(f, result, args...);
else else
nary_nonstandard_impl(f, result, args...); nary_nonstandard_impl(f, result, args...);
} }
template <class... Arguments> template <class... Arguments>
auto nary_nonstandard(argument result, Arguments... args) auto nary_nonstandard(argument result, Arguments... args)
{ {
return [=](auto f) { return [=](auto f) { nary_nonstandard_impl(f, result, args...); };
nary_nonstandard_impl(f, result, args...);
};
} }
template <class... Arguments> template <class... Arguments>
auto nary_standard(argument result, Arguments... args) auto nary_standard(argument result, Arguments... args)
{ {
return [=](auto f) { return [=](auto f) { nary_standard_impl(f, result, args...); };
nary_standard_impl(f, result, args...);
};
} }
template <class... Arguments> template <class... Arguments>
auto nary(argument result, Arguments... args) auto nary(argument result, Arguments... args)
{ {
return [=](auto f) { return [=](auto f) { nary_impl(f, result, args...); };
nary_impl(f, result, args...);
};
} }
inline auto nary(const argument& result, const argument& arg1, const argument& arg2) inline auto nary(const argument& result, const argument& arg1, const argument& arg2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment