Commit 683b6e0e authored by Michael Carilli's avatar Michael Carilli
Browse files

Quick kernel to clean up l2norm

parent 1a48b26b
......@@ -56,6 +56,20 @@ struct L2NormFunctor
}
};
// Final reduction step for the L2 norm: folds the leading entries of `x`
// into a single value and writes its square root to `*ret`.
//
// Parameters:
//   x   - device buffer; the first 320 floats are read (one per thread).
//         NOTE(review): 320 presumably matches the length of the partial-result
//         buffer produced upstream -- confirm against the allocation of `output`.
//   ret - device pointer to a single float that receives the result.
//
// Launch assumptions:
//   * The kernel indexes only by threadIdx.x, so it is meant to run as a
//     single block; the visible call site launches <<<1, 512>>>.
//   * The 512-float shared array is passed to reduce_block_into_lanes as
//     scratch space; presumably blockDim.x must not exceed 512 -- confirm
//     against that helper's definition.
__global__ void cleanup(float* x, float* ret)
{
  __shared__ float vals[512];

  // Threads at or beyond index 320 contribute a neutral 0 to the reduction.
  float val = 0.0f;
  if(threadIdx.x < 320)
    val = x[threadIdx.x];

  // Block-wide reduction (helper defined elsewhere); thread 0 consumes the result.
  float reduced = reduce_block_into_lanes(vals, val);

  if(threadIdx.x == 0)
    *ret = sqrtf(reduced);  // explicit single-precision square root
}
at::Tensor multi_tensor_l2norm_cuda(
int chunk_size,
at::Tensor noop_flag,
......@@ -76,8 +90,11 @@ at::Tensor multi_tensor_l2norm_cuda(
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves two more small kernel launches, but will be negligible end to end.
// This involves one more small kernel launch, but will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
return output.sum().sqrt();
auto ret = at::empty({1}, output.options());
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<1, 512, 0, stream>>>(output.data<float>(), ret.data<float>());
return ret;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment