"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "f69a3a44e3f54537a545659c62001247b079a62e"
Unverified commit 93338e62, authored by mcarilli, committed by GitHub

Give multi-tensor L2 norm the ability to compute norms per-tensor as well as globally (#333)

* Existing tests passing, still need to add per-tensor tests

* Test is passing, still need to measure performance

* ILP for l2norm functor
parent a151575c
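For orientation, here is a minimal sketch of how the extended op is driven from Python, assuming apex's multi_tensor_applier wrapper and the compiled amp_C extension are importable; the tensor shapes are illustrative.

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# Shared flag buffer: the fused kernels set it nonzero on inf/nan.
overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
params = [torch.randn(4096, device='cuda') for _ in range(8)]

# With this change the op returns (global_norm, per_tensor_norms);
# the trailing True asks for the per-tensor norms as well.
norm, per_tensor_norms = multi_tensor_applier(
    amp_C.multi_tensor_l2norm, overflow_buf, [params], True)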
@@ -14,10 +14,11 @@ void multi_tensor_axpby_cuda(
   float b,
   int arg_to_check);
 
-at::Tensor multi_tensor_l2norm_cuda(
+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
   int chunk_size,
   at::Tensor noop_flag,
-  std::vector<std::vector<at::Tensor>> tensor_lists);
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  at::optional<bool> per_tensor_python);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
...
@@ -20,6 +20,7 @@ template<int n> struct TensorListMetadata
   int sizes[depth_to_max_tensors[n-1]];
   unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
   int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
+  int start_tensor_this_launch;
 };
@@ -66,6 +67,7 @@ void multi_tensor_apply(
   auto stream = at::cuda::getCurrentCUDAStream();
 
+  tl.start_tensor_this_launch = 0;
   int loc_block_info = 0;
   int loc_tensor_info = 0;
   for(int t = 0; t < ntensors; t++)
@@ -106,6 +108,7 @@ void multi_tensor_apply(
         {
          // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
          loc_tensor_info = 0;
+         tl.start_tensor_this_launch = t + 1;
         }
         else
         {
@@ -114,6 +117,7 @@ void multi_tensor_apply(
          for(int d = 0; d < depth; d++)
            tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
          loc_tensor_info = 1;
+         tl.start_tensor_this_launch = t;
         }
       }
     }
...
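The new start_tensor_this_launch field exists because multi_tensor_apply may split one tensor list across several kernel launches; inside a launch the functor only knows tensor_loc, its position within that launch, so the metadata records the global index of the launch's first tensor. A small Python model of the resulting flat index into the per-tensor output buffer (names mirror the CUDA code; the function itself is hypothetical):

def per_tensor_slot(start_tensor_this_launch, tensor_loc, chunk_idx,
                    max_chunks_per_tensor):
    # Global tensor index = launch offset + position within this launch.
    global_t = start_tensor_this_launch + tensor_loc
    # Each tensor owns a fixed stride of max_chunks_per_tensor slots.
    return global_t * max_chunks_per_tensor + chunk_idx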
@@ -16,11 +16,14 @@
 template<typename x_t>
 struct L2NormFunctor
 {
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
     TensorListMetadata<1>& tl,
-    float* output)
+    float* output,
+    float* output_per_tensor,
+    bool per_tensor,
+    int max_chunks_per_tensor)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -35,47 +38,114 @@ struct L2NormFunctor
     n -= chunk_idx*chunk_size;
 
-    __shared__ float vals[512];
+    __shared__ float s_vals[512];
 
-    // Non-divergent exit condition for __syncthreads, not necessary here
-    float val = 0;
-    for(int i = threadIdx.x; i < n && i < chunk_size; i += blockDim.x)
+    float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
+    for(int i = 0; i < ILP; i++)
+      vals[i] = 0.f;
+
+    for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
     {
-      float next = static_cast<float>(x[i]);
-      val += next*next;
+      #pragma unroll
+      for(int ii = 0; ii < ILP; ii++)
+      {
+        int i = i_start + threadIdx.x + ii*blockDim.x;
+        if(i < n && i < chunk_size)
+        {
+          float next = static_cast<float>(x[i]);
+          vals[ii] += next*next;
+        }
+      }
     }
 
-    float final = reduce_block_into_lanes(vals, val);
+    float val = 0.f;
+    for(int i = 0; i < ILP; i++)
+      val += vals[i];
+
+    float final = reduce_block_into_lanes(s_vals, val);
 
     if(threadIdx.x == 0)
     {
       if(!isfinite(final))
         *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
       output[blockIdx.x] += final;
+      if(per_tensor)
+        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
     }
   }
 };
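The ILP rewrite gives each thread ILP independent accumulators: within a tile of blockDim.x*ILP elements, thread t handles elements i_start + t + ii*blockDim.x, so adjacent threads still touch adjacent addresses (coalesced loads) while each thread folds its ILP partial sums together before the block-wide reduction. A NumPy sketch of the same access pattern, with illustrative sizes:

import numpy as np

ILP, blockDim_x = 4, 512
x = np.random.randn(3000)
n = len(x)

partial = np.zeros((blockDim_x, ILP))          # per-thread register accumulators
for i_start in range(0, n, blockDim_x * ILP):  # one tile per outer iteration
    for ii in range(ILP):
        for t in range(blockDim_x):
            i = i_start + t + ii * blockDim_x  # consecutive t -> consecutive i
            if i < n:
                partial[t, ii] += x[i] * x[i]

val = partial.sum(axis=1)                      # per-thread fold over ILP slots
assert np.isclose(val.sum(), (x * x).sum())    # matches a plain sum of squares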
-__global__ void cleanup(float* x, float* ret)
+__global__ void cleanup(
+  float* output,
+  float* output_per_tensor,
+  float* ret,
+  float* ret_per_tensor,
+  bool per_tensor,
+  int max_chunks_per_tensor)
 {
   __shared__ float vals[512];
 
-  float val = 0;
-  if(threadIdx.x < 320)
-    val = x[threadIdx.x];
+  if(blockIdx.x == 0)
+  {
+    float val = 0;
+    if(threadIdx.x < 320)
+      val = output[threadIdx.x];
 
-  float final = reduce_block_into_lanes(vals, val);
+    float final = reduce_block_into_lanes(vals, val);
 
-  if(threadIdx.x == 0)
-    *ret = sqrt(final);
+    if(threadIdx.x == 0)
+      *ret = sqrt(final);
+  }
+
+  if(per_tensor)
+  {
+    float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;
+
+    float val = 0;
+    for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
+      val += output_this_tensor[i];
+
+    float final = reduce_block_into_lanes(vals, val);
+
+    if(threadIdx.x == 0)
+      ret_per_tensor[blockIdx.x] = sqrt(final);
+  }
 }
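The reworked cleanup kernel now does two reductions in one launch: block 0 still folds the fixed 320-entry global accumulator into *ret, and when per_tensor is set, every block b additionally sums the max_chunks_per_tensor partials belonging to tensor b. A NumPy model of what the kernel produces (host-side, illustrative):

import numpy as np

def cleanup_reference(output, output_per_tensor, per_tensor, max_chunks_per_tensor):
    # Block 0's job: global norm from the 320-entry accumulator buffer.
    ret = np.sqrt(output[:320].sum())
    ret_per_tensor = None
    if per_tensor:
        # Block b's job: norm of tensor b from its row of chunk partials.
        rows = output_per_tensor.reshape(-1, max_chunks_per_tensor)
        ret_per_tensor = np.sqrt(rows.sum(axis=1))
    return ret, ret_per_tensor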
-at::Tensor multi_tensor_l2norm_cuda(
+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
   int chunk_size,
   at::Tensor noop_flag,
-  std::vector<std::vector<at::Tensor>> tensor_lists)
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  at::optional<bool> per_tensor_python)
 {
-  auto output = at::zeros({320}, tensor_lists[0][0].options().dtype(at::kFloat));
+  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
+
+  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
+  auto output = at::zeros({320}, float_options);
+
+  at::Tensor output_per_tensor;
+  at::Tensor ret_per_tensor;
+
+  int ntensors = tensor_lists[0].size();
+  int max_chunks_per_tensor = -1;
+
+  if(per_tensor)
+  {
+    for(int t = 0; t < ntensors; t++)
+    {
+      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
+      if(max_chunks_this_tensor > max_chunks_per_tensor)
+        max_chunks_per_tensor = max_chunks_this_tensor;
+    }
+    output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
+    ret_per_tensor = at::empty({ntensors}, float_options);
+  }
+  else
+  {
+    ret_per_tensor = at::empty({0}, float_options);
+  }
 
   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
     multi_tensor_apply<1>(
@@ -84,7 +154,10 @@ at::Tensor multi_tensor_l2norm_cuda(
       noop_flag,
       tensor_lists,
       L2NormFunctor<scalar_t_0>(),
-      output.data<float>());)
+      output.data<float>(),
+      per_tensor ? output_per_tensor.data<float>() : nullptr,
+      per_tensor,
+      max_chunks_per_tensor);)
 
   AT_CUDA_CHECK(cudaGetLastError());
@@ -95,6 +168,13 @@ at::Tensor multi_tensor_l2norm_cuda(
   // logic, but keeping it simple for now
   auto ret = at::empty({1}, output.options());
   auto stream = at::cuda::getCurrentCUDAStream();
-  cleanup<<<1, 512, 0, stream>>>(output.data<float>(), ret.data<float>());
-  return ret;
+  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
+    output.data<float>(),
+    per_tensor ? output_per_tensor.data<float>() : nullptr,
+    ret.data<float>(),
+    per_tensor ? ret_per_tensor.data<float>() : nullptr,
+    per_tensor,
+    max_chunks_per_tensor);
+
+  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
 }
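On the host side, the per-tensor buffer is padded to the largest chunk count so every tensor gets a fixed-stride row; the sizing loop above is the usual ceiling division (numel + chunk_size - 1)/chunk_size. A quick check of that arithmetic in Python (the helper name is illustrative):

def max_chunks_per_tensor(numels, chunk_size):
    # Ceiling division per tensor, then the max, matching the host loop above.
    return max((n + chunk_size - 1) // chunk_size for n in numels)

# 65537 elements need 33 chunks of 2048; smaller tensors pad up to that stride.
assert max_chunks_per_tensor([100, 2048, 65537], chunk_size=2048) == 33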
@@ -32,7 +32,7 @@ class TestMultiTensorL2Norm(unittest.TestCase):
         pass
 
     # The tensor creation here is written for convenience, not speed.
-    def l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type):
+    def l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type, per_tensor):
         self.overflow_buf.zero_()
         a = torch.cuda.FloatTensor(sizea).fill_(self.val)
         b = torch.cuda.FloatTensor(sizeb).fill_(self.val)
@@ -41,12 +41,18 @@ class TestMultiTensorL2Norm(unittest.TestCase):
         for i in range(repeat_tensors):
             in_list += [a.clone().to(in_type), b.clone().to(in_type)]
 
-        norm = applier(multi_tensor_l2norm, self.overflow_buf, [in_list])
+        if per_tensor:
+            norm, norm_per_tensor = applier(multi_tensor_l2norm, self.overflow_buf, [in_list], True)
+            normab = torch.cat((a.norm().view(1), b.norm().view(1)))
+            norm_per_tensor = norm_per_tensor.view(-1, 2)
+        else:
+            norm, _ = applier(multi_tensor_l2norm, self.overflow_buf, [in_list], True)
 
         reference = torch.cuda.FloatTensor((sizea + sizeb)*repeat_tensors).fill_(self.val).norm()
 
         self.assertTrue(torch.allclose(norm, reference))
+        if per_tensor:
+            self.assertTrue(torch.allclose(norm_per_tensor, normab))
         self.assertTrue(self.overflow_buf.item() == 0)
 
     @unittest.skipIf(disabled, "amp_C is unavailable")
@@ -72,7 +78,8 @@ class TestMultiTensorL2Norm(unittest.TestCase):
         for applier in appliers:
             for repeat in repeat_tensors:
                 for in_type in (torch.float32, torch.float16):
-                    self.l2norm(sizea, sizeb, applier, repeat, in_type)
+                    for per_tensor in (False, True):
+                        self.l2norm(sizea, sizeb, applier, repeat, in_type, per_tensor)
...
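Since the test interleaves clones of a and b, reshaping the flat per-tensor norms to (-1, 2) lines each row up against (||a||, ||b||); a standalone sketch of that bookkeeping with illustrative sizes:

import torch

a = torch.randn(100)
b = torch.randn(50)
in_list = [a.clone(), b.clone(), a.clone(), b.clone()]  # repeat_tensors = 2

flat_norms = torch.stack([t.norm() for t in in_list])   # per-tensor norms, flat
normab = torch.cat((a.norm().view(1), b.norm().view(1)))

# Every (a, b) pair becomes one row, so each row should equal normab.
assert torch.allclose(flat_norms.view(-1, 2), normab.expand(2, 2))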