Commit 063ba0c4 authored by Paul
Browse files

Hacked fixes for pointwise

parent f449cd1d
......@@ -43,9 +43,9 @@ struct pointwise_compiler : compiler<pointwise_compiler>
// Returns the oversubscription factor chosen for the given input shapes.
// As the statements read here: 1 when any input shape reports broadcasted(),
// otherwise 4.
// NOTE(review): this text appears to be a diff hunk whose +/- markers were stripped; the
// live broadcast check and the commented-out copy directly below it are presumably the
// before/after versions of the same three lines, in which case the committed function
// would unconditionally return 4 -- confirm against the real file.
static std::size_t oversubscribe(const std::vector<shape>& inputs)
{
if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); }))
return 1;
else
// if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); }))
// return 1;
// else
return 4;
}
static std::size_t vectorize_elements(const std::vector<shape>& inputs)
......
......@@ -114,7 +114,7 @@ __device__ auto preload(index idx, Ts... xs)
constexpr auto size = decltype(compute_preload_size<type>(make_shape_type(xs)...)){};
const index_int max_size = 512 * sizeof(type);
return [=](auto f) {
if constexpr(size > 0 and size < max_size)
if constexpr(size > 0 and size < max_size and false)
{
__shared__ type buffer[size];
preload_copy(idx, f, buffer, xs...);
......
......@@ -109,15 +109,15 @@ constexpr index_int find_vector_axis_c(Shape s)
template <class... Shapes>
constexpr index_int find_vector_axis_c(Shapes... ss)
{
const bool all_broadcasted = (ss.broadcasted() and ...);
// const bool all_broadcasted = (ss.broadcasted() and ...);
index_int axis = 0;
bool b = false;
by([&](auto s) {
if(b)
return;
// Skip broadcasted shapes if there are shapes not broadcasted
if(not all_broadcasted and s.broadcasted())
return;
// if(not all_broadcasted and s.broadcasted())
// return;
axis = find_vector_axis_c(s);
if(s.strides[axis] == 1)
b = true;
......@@ -139,7 +139,7 @@ constexpr auto is_vectorizable_c(Axis axis, Shapes... ss)
return ((axis < ss.lens.size() and ss.lens[axis] % N == 0 and
// Only vectorize broadcasted types with stride 0, since this causes issues in the
// preloader
((not ss.broadcasted() and ss.strides[axis] == 1) or ss.strides[axis] == 0)) and
((ss.strides[axis] == 1) or ss.strides[axis] == 0)) and
...);
}
......@@ -152,9 +152,10 @@ constexpr auto is_vectorizable(Axis, Shapes...)
// Picks a vector width by probing `pred` with compile-time constants: each branch is chosen
// from the type of pred(_c<N>), value-initialized and tested inside `if constexpr`.
// The candidate results visible here are _c<4>, _c<2> and _c<0>.
// NOTE(review): this text appears to be a diff hunk whose +/- markers were stripped. The
// live `_c<4>` probe and the commented-out copy below it are presumably the before/after
// versions of the same lines, which would leave only the `_c<2>` probe with `_c<0>` as the
// fallback -- confirm against the real file.
// NOTE(review): read literally, the final `else` pairs with the nearest `if constexpr`
// (the second `_c<2>` probe), not with the first one.
template <class P>
constexpr auto find_vectorize_size(P pred)
{
if constexpr(decltype(pred(_c<4>)){})
return _c<4>;
else if constexpr(decltype(pred(_c<2>)){})
// if constexpr(decltype(pred(_c<4>)){})
// return _c<4>;
// else
if constexpr(decltype(pred(_c<2>)){})
return _c<2>;
else
return _c<0>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment