"src/vscode:/vscode.git/clone" did not exist on "3f1411767bc0f1837adb6f289713807f18599db3"
Commit 3b04798c authored by Paul
Browse files

Formatting

parent fbcb4570
...@@ -12,9 +12,9 @@ constexpr T normalize(unsigned long z) ...@@ -12,9 +12,9 @@ constexpr T normalize(unsigned long z)
{ {
if(z == 0) if(z == 0)
return 0; return 0;
const auto max = 32768; const auto max = 32768;
const double range = max / 2; const double range = max / 2;
double result = (z % max) / range; double result = (z % max) / range;
result -= 1; result -= 1;
return result; return result;
} }
......
...@@ -27,7 +27,7 @@ template <class T> ...@@ -27,7 +27,7 @@ template <class T>
vec4<T> vec4_load(T* x, size_t i) vec4<T> vec4_load(T* x, size_t i)
{ {
vec4<T> result; vec4<T> result;
auto n = i * 4; auto n = i * 4;
result[0] = x[n + 0]; result[0] = x[n + 0];
result[1] = x[n + 1]; result[1] = x[n + 1];
result[2] = x[n + 2]; result[2] = x[n + 2];
...@@ -85,9 +85,9 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2) ...@@ -85,9 +85,9 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
auto* yp = input2.data(); auto* yp = input2.data();
auto* outp = output.data(); auto* outp = output.data();
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal; const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size(); const std::size_t n = output.size();
launch(nglobal, nlocal)([=](auto idx) __device__ { launch(nglobal, nlocal)([=](auto idx) __device__ {
__shared__ type buffer[2048]; __shared__ type buffer[2048];
...@@ -100,10 +100,10 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2) ...@@ -100,10 +100,10 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
// Process the data // Process the data
for(size_t i = idx.global; i < n; i += nglobal) for(size_t i = idx.global; i < n; i += nglobal)
{ {
auto bidx = i % bdim_len; auto bidx = i % bdim_len;
auto b = buffer[bidx]; auto b = buffer[bidx];
type x = xp[i]; type x = xp[i];
outp[i] = f(x, b); outp[i] = f(x, b);
} }
}); });
#else #else
...@@ -131,7 +131,7 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2) ...@@ -131,7 +131,7 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
{ {
buffer[bdim_vec_len][i] = yp[bdim_vec_len][i]; buffer[bdim_vec_len][i] = yp[bdim_vec_len][i];
} }
for(size_t i = idx.local; i < (vec_size-bdim_vec_rem); i += nlocal) for(size_t i = idx.local; i < (vec_size - bdim_vec_rem); i += nlocal)
{ {
buffer[bdim_vec_len][i] = yp[0][i]; buffer[bdim_vec_len][i] = yp[0][i];
} }
...@@ -224,15 +224,16 @@ inline auto nary(argument result, argument arg1, argument arg2) ...@@ -224,15 +224,16 @@ inline auto nary(argument result, argument arg1, argument arg2)
arg2.get_shape().strides().end(), arg2.get_shape().strides().end(),
[](auto x) { return x != 0; }) == 1) [](auto x) { return x != 0; }) == 1)
{ {
auto not_zero = [](auto x) { return x != 0; }; auto not_zero = [](auto x) { return x != 0; };
const auto& strides = arg2.get_shape().strides(); const auto& strides = arg2.get_shape().strides();
auto stride_it = std::find_if(strides.begin(), auto stride_it = std::find_if(strides.begin(), strides.end(), not_zero);
strides.end(), not_zero); auto stride_idx = std::distance(strides.begin(), stride_it);
auto stride_idx = std::distance(strides.begin(), stride_it); auto stride_len = arg2.get_shape().lens()[stride_idx];
auto stride_len = arg2.get_shape().lens()[stride_idx];
// TODO: Dont require disibility by 4 // TODO: Dont require disibility by 4
bool divisible_by_4 = (stride_len % 4 == 0) and (arg1.get_shape().elements() % 4 == 0); bool divisible_by_4 = (stride_len % 4 == 0) and (arg1.get_shape().elements() % 4 == 0);
if(divisible_by_4 and stride_len <= 2048 and std::none_of(std::next(stride_it), strides.end(), not_zero)) { if(divisible_by_4 and stride_len <= 2048 and
std::none_of(std::next(stride_it), strides.end(), not_zero))
{
binary_broadcast(result, arg1, arg2)(f); binary_broadcast(result, arg1, arg2)(f);
return; return;
} }
......
...@@ -77,7 +77,7 @@ struct auto_print ...@@ -77,7 +77,7 @@ struct auto_print
}; };
std::array<std::function<void()>, 2> auto_print::handlers = {}; std::array<std::function<void()>, 2> auto_print::handlers = {};
template<class T> template <class T>
auto get_hash(const T& x) auto get_hash(const T& x)
{ {
return std::hash<T>{}(x); return std::hash<T>{}(x);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment