Commit 683b6e0e authored by Michael Carilli's avatar Michael Carilli
Browse files

Quick kernel to clean up l2norm

parent 1a48b26b
...@@ -56,6 +56,20 @@ struct L2NormFunctor ...@@ -56,6 +56,20 @@ struct L2NormFunctor
} }
}; };
__global__ void cleanup(float* x, float* ret)
{
__shared__ float vals[512];
float val = 0;
if(threadIdx.x < 320)
val = x[threadIdx.x];
float final = reduce_block_into_lanes(vals, val);
if(threadIdx.x == 0)
*ret = sqrt(final);
}
at::Tensor multi_tensor_l2norm_cuda( at::Tensor multi_tensor_l2norm_cuda(
int chunk_size, int chunk_size,
at::Tensor noop_flag, at::Tensor noop_flag,
...@@ -76,8 +90,11 @@ at::Tensor multi_tensor_l2norm_cuda( ...@@ -76,8 +90,11 @@ at::Tensor multi_tensor_l2norm_cuda(
// AT_CUDA_CHECK(cudaDeviceSynchronize()); // AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves two more small kernel launches, but will be negligible end to end. // This involves one more small kernel launches, but will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence // I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now // logic, but keeping it simple for now
return output.sum().sqrt(); auto ret = at::empty({1}, output.options());
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<1, 512, 0, stream>>>(output.data<float>(), ret.data<float>());
return ret;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment