Commit 6459204c authored by Paul

Fix broadcast tensor op

parent 3b04798c
@@ -80,6 +80,7 @@ std::vector<T> generate_tensor_data(const migraph::shape& s, unsigned long seed
     std::vector<T> result(s.elements());
     std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
     // std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
+    // std::generate(result.begin(), result.end(), []{ return 1; });
     return result;
 }
...
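
Note: xorshf96_generator is defined elsewhere in the test utilities; only its use is visible in this hunk. A minimal sketch of what such a functor could look like, assuming it wraps Marsaglia's xorshf96 xorshift generator (the seeding and the mapping to T here are guesses, not the repository's actual code):

    // Hypothetical sketch of a seedable xorshf96-style functor.
    template <class T>
    struct xorshf96_generator
    {
        unsigned long x = 123456789;
        unsigned long y = 362436069;
        unsigned long z = 521288629;

        explicit xorshf96_generator(unsigned long seed) { x ^= seed; } // assumed seeding

        T operator()()
        {
            // Marsaglia's xorshf96: three xorshifts plus a shuffle of the state words
            x ^= x << 16;
            x ^= x >> 5;
            x ^= x << 1;
            unsigned long t = x;
            x = y;
            y = z;
            z = t ^ x ^ y;
            // Reduce to a small range so values stay exactly representable in T
            return static_cast<T>(z % 17);
        }
    };

The commented-out alternatives in the hunk swap this for modulo or constant data, which is useful when bisecting a miscompare down to an indexing bug.
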
@@ -8,7 +8,7 @@ namespace device {
 void add_relu(argument result, argument arg1, argument arg2)
 {
     nary(std::move(result), std::move(arg1), std::move(arg2))(
-        [](auto x, auto y) { return max(0, x + y); });
+        [](auto x, auto y) { return std::max<decltype(x + y)>(0, x + y); });
 }
 } // namespace device
...
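
This change fixes the ReLU clamp. With 0 as an int literal and x + y floating-point, an unqualified max(0, x + y) either fails template-argument deduction (std::max requires both arguments to have the same type) or, depending on which overloads are in scope in device code, resolves to an integer overload and silently truncates. Spelling the template argument as decltype(x + y) forces the 0 to convert to the operands' type. A standalone illustration (not repository code):

    #include <algorithm>

    float relu_add(float x, float y)
    {
        // std::max(0, x + y) is ill-formed: T deduces to both int and float.
        // Forcing T keeps the comparison in floating point:
        return std::max<decltype(x + y)>(0, x + y);
    }
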
@@ -14,25 +14,15 @@ template <class T>
 using vec4 = T __attribute__((ext_vector_type(4)));
 template <class T>
-vec4<T>* as_vec4(T* x)
+__device__ __host__ vec4<T>* as_vec4(T* x)
 {
-    std::uintptr_t a = reinterpret_cast<std::uintptr_t>(x);
-    if(a % 32 != 0)
-        throw std::runtime_error("Memory not aligned for vector operations");
     return reinterpret_cast<vec4<T>*>(x);
-    // return (vec4<T>*)(x);
 }
 template <class T>
-vec4<T> vec4_load(T* x, size_t i)
+__device__ __host__ T* as_pointer(vec4<T>* x)
 {
-    vec4<T> result;
-    auto n = i * 4;
-    result[0] = x[n + 0];
-    result[1] = x[n + 1];
-    result[2] = x[n + 2];
-    result[3] = x[n + 3];
-    return result;
+    return reinterpret_cast<T*>(x);
 }
 template <class... Ts>
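
For context, vec4<T> uses Clang's ext_vector_type attribute, which hipcc also supports: vec4<float> is a single 16-byte value with array-style element access, so one subscript moves four scalars. The old vec4_load assembled the vector element by element; as_vec4 instead reinterprets the buffer so a plain subscript performs the wide load, and the new as_pointer converts back for scalar access. The throwing alignment check is dropped, presumably because the helpers are now __device__-callable, where exceptions are unavailable. A small standalone sketch (Clang-only, not repository code):

    #include <cstdio>

    template <class T>
    using vec4 = T __attribute__((ext_vector_type(4)));

    int main()
    {
        float data[8] = {0, 1, 2, 3, 4, 5, 6, 7};
        // What as_vec4 does: view the scalar buffer as vec4 lanes
        auto* v = reinterpret_cast<vec4<float>*>(data);
        vec4<float> second = v[1];      // one 16-byte load of elements 4..7
        std::printf("%f\n", second[2]); // prints 6.000000
        // What as_pointer does: recover scalar access into the same memory
        float* p = reinterpret_cast<float*>(v);
        std::printf("%f\n", p[4]);      // prints 4.000000
    }
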
@@ -67,46 +57,21 @@ auto nary_nonstandard(argument result, Arguments... args)
     return [=](auto f) { return nary_nonstandard_impl(f, result, args...); };
 }
-inline auto binary_broadcast(argument result, argument arg1, argument arg2)
+inline auto binary_broadcast_vec(argument result, argument arg1, argument arg2)
 {
     return [=](auto f) {
-        // const auto& output_shape = result.get_shape();
+        const auto& output_shape = result.get_shape();
         const auto& b_shape = arg2.get_shape();
         auto bdim = std::distance(b_shape.strides().begin(),
                                   std::find_if(b_shape.strides().begin(),
                                                b_shape.strides().end(),
                                                [](auto x) { return x != 0; }));
-        auto bdim_len = b_shape.lens()[bdim];
+        auto bdim_len = output_shape.lens()[bdim];
+        auto bdim_stride = output_shape.strides()[bdim];
+        auto bdim_next_stride = bdim_stride * bdim_len;
         visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
             using type = std::remove_cv_t<typename decltype(output)::value_type>;
-#if 1
-            auto* xp = input1.data();
-            auto* yp = input2.data();
-            auto* outp = output.data();
-            const std::size_t nlocal = 1024;
-            const std::size_t nglobal = 256 * nlocal;
-            const std::size_t n = output.size();
-            launch(nglobal, nlocal)([=](auto idx) __device__ {
-                __shared__ type buffer[2048];
-                // Load bias into LDS
-                for(size_t i = idx.local; i < bdim_len; i += nlocal)
-                {
-                    buffer[i] = yp[i];
-                }
-                __syncthreads();
-                // Process the data
-                for(size_t i = idx.global; i < n; i += nglobal)
-                {
-                    auto bidx = i % bdim_len;
-                    auto b = buffer[bidx];
-                    type x = xp[i];
-                    outp[i] = f(x, b);
-                }
-            });
-#else
             auto* xp = as_vec4(input1.data());
             auto* yp = as_vec4(input2.data());
             auto* outp = as_vec4(output.data());
@@ -115,63 +80,86 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
             const std::size_t nlocal = 1024;
             const std::size_t nglobal = 256 * nlocal;
             const std::size_t n = output.size() / vec_size;
-            const std::size_t rem = output.size() % vec_size;
             const std::size_t bdim_vec_len = bdim_len / vec_size;
-            const std::size_t bdim_vec_rem = bdim_len % vec_size;
             launch(nglobal, nlocal)([=](auto idx) __device__ {
-                __shared__ vec4<type> buffer[2048];
+                __shared__ vec4<type> buffer[2048 / vec_size];
                 // Load bias into LDS
                 for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
                 {
                     buffer[i] = yp[i];
                 }
-                // Load remaining bias data
-                for(size_t i = idx.local; i < bdim_vec_rem; i += nlocal)
-                {
-                    buffer[bdim_vec_len][i] = yp[bdim_vec_len][i];
-                }
-                for(size_t i = idx.local; i < (vec_size - bdim_vec_rem); i += nlocal)
-                {
-                    buffer[bdim_vec_len][i] = yp[0][i];
-                }
                 __syncthreads();
+                auto* bp = as_pointer(buffer);
                 // Process the data
                 for(size_t i = idx.global; i < n; i += nglobal)
                 {
-                    auto bidx = bdim_vec_len == 0 ? 0 : i % bdim_vec_len;
-                    auto b = buffer[bidx];
+                    auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
+                    auto b = bp[bidx];
                     vec4<type> x = xp[i];
                     vec4<type> out = outp[i];
                     for(std::size_t j = 0; j < vec_size; j++)
                     {
-                        out[j] = f(x[j], b[j]);
+                        out[j] = f(x[j], b);
                     }
                     outp[i] = out;
                 }
-                for(size_t i = idx.global; i < rem; i += nglobal)
-                {
-                    outp[n][i] = f(xp[n][i], buffer[bdim_vec_len][i]);
-                }
             });
-#endif
         });
     };
 }
+inline auto binary_broadcast(argument result, argument arg1, argument arg2)
+{
+    return [=](auto f) {
+        const auto& output_shape = result.get_shape();
+        const auto& b_shape = arg2.get_shape();
+        auto bdim = std::distance(b_shape.strides().begin(),
+                                  std::find_if(b_shape.strides().begin(),
+                                               b_shape.strides().end(),
+                                               [](auto x) { return x != 0; }));
+        auto bdim_len = output_shape.lens()[bdim];
+        auto bdim_stride = output_shape.strides()[bdim];
+        auto bdim_next_stride = bdim_stride * bdim_len;
+        visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
+            using type = std::remove_cv_t<typename decltype(output)::value_type>;
+            auto* xp = input1.data();
+            auto* yp = input2.data();
+            auto* outp = output.data();
+            const std::size_t nlocal = 1024;
+            const std::size_t nglobal = 256 * nlocal;
+            const std::size_t n = output.size();
+            launch(nglobal, nlocal)([=](auto idx) __device__ {
+                __shared__ type buffer[2048];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_len; i += nlocal)
+                {
+                    buffer[i] = yp[i];
+                }
+                __syncthreads();
+                // Process the data
+                for(size_t i = idx.global; i < n; i += nglobal)
+                {
+                    auto bidx = (i % bdim_next_stride) / bdim_stride;
+                    auto b = buffer[bidx];
+                    type x = xp[i];
+                    outp[i] = f(x, b);
+                }
+            });
+        });
+    };
+}
 template <class... Arguments>
-auto nary_standard(argument result, Arguments... args)
+auto nary_standard_vec(argument result, Arguments... args)
 {
     return [=](auto f) {
         // assert(x.get_shape().elements() == y.get_shape().elements());
         const auto& output_shape = result.get_shape();
         visit_all(result, args...)([&](auto output, auto... inputs) {
-#if 1
-            auto data = pack(inputs.data()...);
-            auto* outp = output.data();
-            gs_launch(output_shape.elements())(
-                [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
-#else
             using type = std::remove_cv_t<typename decltype(output)::value_type>;
             const std::size_t vec_size = 4;
             auto data = pack_vec4(inputs.data()...);
@@ -188,7 +176,21 @@ auto nary_standard(argument result, Arguments... args)
                     i);
                 outp[i] = out;
             });
-#endif
+        });
+    };
+}
+template <class... Arguments>
+auto nary_standard(argument result, Arguments... args)
+{
+    return [=](auto f) {
+        // assert(x.get_shape().elements() == y.get_shape().elements());
+        const auto& output_shape = result.get_shape();
+        visit_all(result, args...)([&](auto output, auto... inputs) {
+            auto data = pack(inputs.data()...);
+            auto* outp = output.data();
+            gs_launch(output_shape.elements())(
+                [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
         });
     };
 }
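
The heart of the fix is the bias index. The old code indexed the bias with i % bdim_len (and i % bdim_vec_len in the vectorized path), which is only correct when the broadcast axis is the innermost one. For a standard shape, the coordinate along axis bdim of the element at linear offset i is (i % bdim_next_stride) / bdim_stride, which is what both new kernels use; the vectorized kernel scales i by vec_size first, since each iteration covers four scalars. A host-side check using the {2, 4, 8} shape from the new test below (strides {32, 8, 1}, broadcast on axis 1):

    #include <cassert>
    #include <cstddef>

    int main()
    {
        const std::size_t bdim_len         = 4;                      // lens[1]
        const std::size_t bdim_stride      = 8;                      // strides[1]
        const std::size_t bdim_next_stride = bdim_stride * bdim_len; // 32
        for(std::size_t i = 0; i < 64; i++)
        {
            // Coordinate along axis 1, recovered from the linear offset
            std::size_t bidx = (i % bdim_next_stride) / bdim_stride;
            // Same coordinate spelled directly:
            assert(bidx == (i / bdim_stride) % bdim_len);
            // The old formula fails once axis 1 is not innermost:
            // at i = 8 the correct bias row is 1, but 8 % bdim_len == 0.
        }
        return 0;
    }
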
@@ -219,22 +221,23 @@ inline auto nary(argument result, argument arg1, argument arg2)
 {
     return [=](auto f) {
         // TODO: Check result and arg1 shape is the same
-        if(arg1.get_shape().standard() and arg2.get_shape().broadcasted() and
-           std::count_if(arg2.get_shape().strides().begin(),
-                         arg2.get_shape().strides().end(),
-                         [](auto x) { return x != 0; }) == 1)
+        if(arg1.get_shape().standard() and arg2.get_shape().broadcasted())
         {
             auto not_zero = [](auto x) { return x != 0; };
             const auto& strides = arg2.get_shape().strides();
-            auto stride_it = std::find_if(strides.begin(), strides.end(), not_zero);
-            auto stride_idx = std::distance(strides.begin(), stride_it);
-            auto stride_len = arg2.get_shape().lens()[stride_idx];
-            // TODO: Dont require disibility by 4
-            bool divisible_by_4 = (stride_len % 4 == 0) and (arg1.get_shape().elements() % 4 == 0);
-            if(divisible_by_4 and stride_len <= 2048 and
-               std::none_of(std::next(stride_it), strides.end(), not_zero))
+            auto b_it = std::find_if(strides.begin(), strides.end(), not_zero);
+            auto b_idx = std::distance(strides.begin(), b_it);
+            auto b_len = result.get_shape().lens()[b_idx];
+            auto b_stride = result.get_shape().strides()[b_idx];
+            assert(arg2.get_shape().lens()[b_idx] == b_len);
+            if(b_len <= 2048 and
+               std::none_of(std::next(b_it), strides.end(), not_zero))
             {
-                binary_broadcast(result, arg1, arg2)(f);
+                const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and (arg1.get_shape().elements() % 4 == 0);
+                if(divisible_by_4)
+                    binary_broadcast_vec(result, arg1, arg2)(f);
+                else
+                    binary_broadcast(result, arg1, arg2)(f);
                 return;
             }
         }
...
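
The dispatch now separates "can use the LDS broadcast kernel at all" (bias fits in the 2048-element buffer and only one axis is broadcast) from "can vectorize" (bias length, bias stride in the result shape, and total element count all divisible by 4), falling back to the scalar kernel instead of rejecting the case, which resolves the old "Dont require disibility by 4" TODO. A quick standalone check of the predicate for the shapes in test_add_broadcast5 below:

    #include <cstddef>
    #include <iostream>

    int main()
    {
        // Values from test_add_broadcast5: result {2, 4, 8}, bias on axis 1
        const std::size_t b_len    = 4;          // result lens[1]
        const std::size_t b_stride = 8;          // result strides[1]
        const std::size_t elements = 2 * 4 * 8;  // elements of x
        const bool divisible_by_4 =
            (b_len % 4 == 0) and (b_stride % 4 == 0) and (elements % 4 == 0);
        // Prints "binary_broadcast_vec"; a bias of length 3, say, would
        // instead route to the scalar binary_broadcast fallback.
        std::cout << (divisible_by_4 ? "binary_broadcast_vec" : "binary_broadcast")
                  << "\n";
    }
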
@@ -260,6 +260,20 @@ struct test_add_broadcast4
     }
 };
+struct test_add_broadcast5
+{
+    migraph::program create_program() const
+    {
+        migraph::program p;
+        migraph::shape s{migraph::shape::float_type, {3}};
+        auto x  = p.add_parameter("x", {migraph::shape::float_type, {2, 4, 8}});
+        auto y  = p.add_parameter("y", {migraph::shape::float_type, {4}});
+        auto by = p.add_instruction(migraph::broadcast{1}, x, y);
+        p.add_instruction(migraph::add{}, x, by);
+        return p;
+    }
+};
 struct test_conv_relu
 {
     migraph::program create_program() const
@@ -471,6 +485,7 @@ int main()
     verify_program<test_add_broadcast2>();
     verify_program<test_add_broadcast3>();
     verify_program<test_add_broadcast4>();
+    verify_program<test_add_broadcast5>();
     verify_program<test_conv_relu>();
     verify_program<test_add_relu>();
     verify_program<test_conv_pooling>();
...
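
For reference, the new test broadcasts y (shape {4}) along axis 1 of x (shape {2, 4, 8}), so the expected output is out[i][j][k] = x[i][j][k] + y[j]. A plain reference computation of that semantics, independent of migraph:

    #include <array>
    #include <cstddef>

    int main()
    {
        std::array<float, 2 * 4 * 8> x{};                 // test data goes here
        std::array<float, 4> y{{0.5f, 1.5f, 2.5f, 3.5f}}; // bias along axis 1
        std::array<float, 2 * 4 * 8> out{};
        for(std::size_t i = 0; i < 2; i++)
            for(std::size_t j = 0; j < 4; j++)
                for(std::size_t k = 0; k < 8; k++)
                {
                    std::size_t n = i * 32 + j * 8 + k; // strides {32, 8, 1}
                    out[n] = x[n] + y[j];               // y broadcast on axis 1
                }
        return 0;
    }
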