Commit 270194c4 authored by Paul's avatar Paul
Browse files

Add triadd

parent 32396d8f
......@@ -10,6 +10,11 @@ void add(const argument& result, const argument& arg1, const argument& arg2)
nary(result, arg1, arg2)([](auto x, auto y) { return x + y; });
}
void add(const argument& result, const argument& arg1, const argument& arg2, const argument& arg3)
{
nary(result, arg1, arg2, arg3)([](auto x, auto y, auto z) { return x + y + z; });
}
} // namespace device
} // namespace gpu
} // namespace migraph
......@@ -10,6 +10,15 @@ void add_relu(const argument& result, const argument& arg1, const argument& arg2
nary(result, arg1, arg2)([](auto x, auto y) { return std::max<decltype(x + y)>(0, x + y); });
}
void add_relu(const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
nary(result, arg1, arg2, arg3)(
[](auto x, auto y, auto z) { return std::max<decltype(x + y + z)>(0, x + y + z); });
}
} // namespace device
} // namespace gpu
} // namespace migraph
......@@ -51,6 +51,108 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
});
}
template <class F>
void trinary_broadcast_vec_impl(
F f, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = arg3.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = as_vec4(input1.data());
auto* yp = as_vec4(input2.data());
auto* zp = as_vec4(input3.data());
auto* outp = as_vec4(output.data());
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size() / vec_size;
const std::size_t bdim_vec_len = bdim_len / vec_size;
launch(nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPH_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{
buffer[i] = zp[i];
}
__syncthreads();
auto* bp = as_pointer(buffer);
// Process the data
for(size_t i = idx.global; i < n; i += nglobal)
{
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b = bp[bidx];
vec4<type> x = xp[i];
vec4<type> y = yp[i];
vec4<type> out = outp[i];
for(std::size_t j = 0; j < vec_size; j++)
{
out[j] = f(x[j], y[j], b);
}
outp[i] = out;
}
});
});
}
template <class F>
void trinary_broadcast_impl(
F f, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = arg3.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = input1.data();
auto* yp = input2.data();
auto* zp = input2.data();
auto* outp = output.data();
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size();
launch(nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPH_DEVICE_SHARED type buffer[2048];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i] = zp[i];
}
__syncthreads();
// Process the data
for(size_t i = idx.global; i < n; i += nglobal)
{
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b = buffer[bidx];
type x = xp[i];
type y = yp[i];
outp[i] = f(x, y, b);
}
});
});
}
template <class F>
void binary_broadcast_vec_impl(F f,
const argument& result,
......@@ -247,6 +349,36 @@ inline auto nary(const argument& result, const argument& arg1, const argument& a
};
}
inline auto
nary(const argument& result, const argument& arg1, const argument& arg2, const argument& arg3)
{
return [=](auto f) {
// TODO: Check result and arg1 shape is the same
if(arg1.get_shape().standard() and arg2.get_shape().standard() and
arg3.get_shape().broadcasted())
{
auto not_zero = [](auto x) { return x != 0; };
const auto& strides = arg3.get_shape().strides();
auto b_it = std::find_if(strides.begin(), strides.end(), not_zero);
auto b_idx = std::distance(strides.begin(), b_it);
auto b_len = result.get_shape().lens()[b_idx];
auto b_stride = result.get_shape().strides()[b_idx];
assert(arg3.get_shape().lens()[b_idx] == b_len);
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(arg1.get_shape().elements() % 4 == 0);
if(divisible_by_4)
trinary_broadcast_vec_impl(f, result, arg1, arg2, arg3);
else
trinary_broadcast_impl(f, result, arg1, arg2, arg3);
return;
}
}
nary_impl(f, result, arg1, arg2, arg3);
};
}
} // namespace device
} // namespace gpu
} // namespace migraph
......
......@@ -10,6 +10,8 @@ namespace device {
void add(const argument& result, const argument& arg1, const argument& arg2);
void add(const argument& result, const argument& arg1, const argument& arg2, const argument& arg3);
} // namespace device
} // namespace gpu
} // namespace migraph
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment