Commit 063ba0c4 authored by Paul's avatar Paul
Browse files

Hacked fixes for pointwise

parent f449cd1d
...@@ -43,9 +43,9 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -43,9 +43,9 @@ struct pointwise_compiler : compiler<pointwise_compiler>
static std::size_t oversubscribe(const std::vector<shape>& inputs) static std::size_t oversubscribe(const std::vector<shape>& inputs)
{ {
if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); })) // if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); }))
return 1; // return 1;
else // else
return 4; return 4;
} }
static std::size_t vectorize_elements(const std::vector<shape>& inputs) static std::size_t vectorize_elements(const std::vector<shape>& inputs)
......
...@@ -114,7 +114,7 @@ __device__ auto preload(index idx, Ts... xs) ...@@ -114,7 +114,7 @@ __device__ auto preload(index idx, Ts... xs)
constexpr auto size = decltype(compute_preload_size<type>(make_shape_type(xs)...)){}; constexpr auto size = decltype(compute_preload_size<type>(make_shape_type(xs)...)){};
const index_int max_size = 512 * sizeof(type); const index_int max_size = 512 * sizeof(type);
return [=](auto f) { return [=](auto f) {
if constexpr(size > 0 and size < max_size) if constexpr(size > 0 and size < max_size and false)
{ {
__shared__ type buffer[size]; __shared__ type buffer[size];
preload_copy(idx, f, buffer, xs...); preload_copy(idx, f, buffer, xs...);
......
...@@ -109,15 +109,15 @@ constexpr index_int find_vector_axis_c(Shape s) ...@@ -109,15 +109,15 @@ constexpr index_int find_vector_axis_c(Shape s)
template <class... Shapes> template <class... Shapes>
constexpr index_int find_vector_axis_c(Shapes... ss) constexpr index_int find_vector_axis_c(Shapes... ss)
{ {
const bool all_broadcasted = (ss.broadcasted() and ...); // const bool all_broadcasted = (ss.broadcasted() and ...);
index_int axis = 0; index_int axis = 0;
bool b = false; bool b = false;
by([&](auto s) { by([&](auto s) {
if(b) if(b)
return; return;
// Skip broadcasted shapes if there are shapes not broadcasted // Skip broadcasted shapes if there are shapes not broadcasted
if(not all_broadcasted and s.broadcasted()) // if(not all_broadcasted and s.broadcasted())
return; // return;
axis = find_vector_axis_c(s); axis = find_vector_axis_c(s);
if(s.strides[axis] == 1) if(s.strides[axis] == 1)
b = true; b = true;
...@@ -139,7 +139,7 @@ constexpr auto is_vectorizable_c(Axis axis, Shapes... ss) ...@@ -139,7 +139,7 @@ constexpr auto is_vectorizable_c(Axis axis, Shapes... ss)
return ((axis < ss.lens.size() and ss.lens[axis] % N == 0 and return ((axis < ss.lens.size() and ss.lens[axis] % N == 0 and
// Only vectorize broadcasted types with stride 0, since this causes issues in the // Only vectorize broadcasted types with stride 0, since this causes issues in the
// preloader // preloader
((not ss.broadcasted() and ss.strides[axis] == 1) or ss.strides[axis] == 0)) and ((ss.strides[axis] == 1) or ss.strides[axis] == 0)) and
...); ...);
} }
...@@ -152,9 +152,10 @@ constexpr auto is_vectorizable(Axis, Shapes...) ...@@ -152,9 +152,10 @@ constexpr auto is_vectorizable(Axis, Shapes...)
template <class P> template <class P>
constexpr auto find_vectorize_size(P pred) constexpr auto find_vectorize_size(P pred)
{ {
if constexpr(decltype(pred(_c<4>)){}) // if constexpr(decltype(pred(_c<4>)){})
return _c<4>; // return _c<4>;
else if constexpr(decltype(pred(_c<2>)){}) // else
if constexpr(decltype(pred(_c<2>)){})
return _c<2>; return _c<2>;
else else
return _c<0>; return _c<0>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment