Unverified Commit 93338e62 authored by mcarilli, committed by GitHub

Give multi-tensor L2 norm the ability to compute norms per-tensor as well as globally (#333)

* Existing tests passing, still need to add per-tensor tests

* Test is passing, still need to measure performance

* ILP for l2norm functor
parent a151575c
@@ -14,10 +14,11 @@ void multi_tensor_axpby_cuda(
   float b,
   int arg_to_check);
 
-at::Tensor multi_tensor_l2norm_cuda(
+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
   int chunk_size,
   at::Tensor noop_flag,
-  std::vector<std::vector<at::Tensor>> tensor_lists);
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  at::optional<bool> per_tensor_python);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
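For orientation, here is a sketch of how the changed declaration might surface in the binding block shown above. The actual `m.def` line for `multi_tensor_l2norm` is outside this hunk, so the registration and docstring below are assumptions, not the commit's code:

```cpp
#include <torch/extension.h>

// Forward declaration matching the updated signature above.
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

// Assumed registration. pybind11 converts a Python None into an empty
// at::optional<bool>, and the std::tuple return value arrives in Python
// as a 2-tuple (global_norm, per_tensor_norms).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors");
}
```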
@@ -20,6 +20,7 @@ template<int n> struct TensorListMetadata
   int sizes[depth_to_max_tensors[n-1]];
   unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
   int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
+  int start_tensor_this_launch;
 };
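The new field records the global index of the first tensor described by this launch's metadata, so a block can recover a global tensor index even when one logical call is split across several kernel launches. A small sketch of the resulting index arithmetic into the flat per-tensor output buffer (all values here are invented for illustration):

```cpp
#include <cstdio>

// Sketch: how a block locates its slot in the flat per-tensor output buffer.
int main() {
  const int max_chunks_per_tensor = 6;  // hypothetical row width
  int start_tensor_this_launch = 2;     // this launch begins at global tensor 2
  int tensor_loc = 1;                   // block's launch-local tensor index
  int chunk_idx = 3;                    // block's chunk within that tensor
  // Flat index into output_per_tensor, one row of partial sums per tensor:
  int slot = (start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor
             + chunk_idx;
  printf("tensor %d, chunk %d -> output_per_tensor[%d]\n",
         start_tensor_this_launch + tensor_loc, chunk_idx, slot);  // 3, 3 -> 21
}
```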
@@ -66,6 +67,7 @@ void multi_tensor_apply(
   auto stream = at::cuda::getCurrentCUDAStream();
 
+  tl.start_tensor_this_launch = 0;
   int loc_block_info = 0;
   int loc_tensor_info = 0;
   for(int t = 0; t < ntensors; t++)
@@ -106,6 +108,7 @@ void multi_tensor_apply(
         {
           // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
           loc_tensor_info = 0;
+          tl.start_tensor_this_launch = t + 1;
         }
         else
         {
@@ -114,6 +117,7 @@ void multi_tensor_apply(
           for(int d = 0; d < depth; d++)
             tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
           loc_tensor_info = 1;
+          tl.start_tensor_this_launch = t;
         }
       }
     }
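These two hunks handle the launch boundary: if the metadata buffers filled up exactly at the end of a tensor, the next launch starts at tensor `t + 1`; if they filled mid-tensor, that tensor's addresses are carried over and the next launch starts at `t`. A simplified host-side emulation of this bookkeeping (the capacity constants and tensor sizes are invented; the real limits come from `depth_to_max_tensors` and `depth_to_max_blocks`):

```cpp
#include <cstdio>
#include <vector>

// Simplified emulation of multi_tensor_apply's launch bookkeeping, to show
// where start_tensor_this_launch comes from. Constants are made up.
int main() {
  const int kMaxTensors = 3, kMaxBlocks = 4, chunk_size = 4;
  std::vector<int> numels = {4, 8, 12, 4};  // hypothetical tensor sizes

  int loc_block_info = 0, loc_tensor_info = 0, start_tensor_this_launch = 0;
  for (int t = 0; t < (int)numels.size(); t++) {
    loc_tensor_info++;
    int chunks = (numels[t] + chunk_size - 1) / chunk_size;
    for (int chunk = 0; chunk < chunks; chunk++) {
      loc_block_info++;
      bool tensors_full = (loc_tensor_info == kMaxTensors && chunk == chunks - 1);
      bool blocks_full  = (loc_block_info == kMaxBlocks);
      bool last_chunk   = (t == (int)numels.size() - 1 && chunk == chunks - 1);
      if (tensors_full || blocks_full || last_chunk) {
        printf("launch: start_tensor_this_launch = %d\n", start_tensor_this_launch);
        loc_block_info = 0;
        if (chunk == chunks - 1) {  // buffers filled exactly on a tensor boundary
          loc_tensor_info = 0;
          start_tensor_this_launch = t + 1;
        } else {                    // filled mid-tensor: carry this tensor over
          loc_tensor_info = 1;
          start_tensor_this_launch = t;
        }
      }
    }
  }
}
```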
@@ -20,7 +20,10 @@ struct L2NormFunctor
     int chunk_size,
     volatile int* noop_gmem,
     TensorListMetadata<1>& tl,
-    float* output)
+    float* output,
+    float* output_per_tensor,
+    bool per_tensor,
+    int max_chunks_per_tensor)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -35,47 +38,114 @@ struct L2NormFunctor
     n -= chunk_idx*chunk_size;
 
-    __shared__ float vals[512];
+    __shared__ float s_vals[512];
 
-    // Non-divergent exit condition for __syncthreads, not necessary here
-    float val = 0;
-    for(int i = threadIdx.x; i < n && i < chunk_size; i += blockDim.x)
+    float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
+    for(int i = 0; i < ILP; i++)
+      vals[i] = 0.f;
+
+    for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
     {
+#pragma unroll
+      for(int ii = 0; ii < ILP; ii++)
+      {
+        int i = i_start + threadIdx.x + ii*blockDim.x;
+        if(i < n && i < chunk_size)
+        {
           float next = static_cast<float>(x[i]);
-      val += next*next;
+          vals[ii] += next*next;
+        }
+      }
     }
 
-    float final = reduce_block_into_lanes(vals, val);
+    float val = 0.f;
+    for(int i = 0; i < ILP; i++)
+      val += vals[i];
+
+    float final = reduce_block_into_lanes(s_vals, val);
 
     if(threadIdx.x == 0)
     {
       if(!isfinite(final))
         *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
       output[blockIdx.x] += final;
+      if(per_tensor)
+        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
     }
   }
 };
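The functor now unrolls the read loop by a factor of `ILP` (instruction-level parallelism), keeping `ILP` partial sums in registers before a single block-wide reduction. A standalone CUDA sketch of the same pattern, using a plain shared-memory tree reduction in place of apex's `reduce_block_into_lanes` (it assumes `blockDim.x` is a power of two, at most 512):

```cuda
#include <cuda_runtime.h>

// Sketch of the ILP accumulation pattern above: each thread keeps ILP partial
// sums in registers, strided so consecutive threads touch consecutive
// elements, then the block reduces in shared memory.
// Launch e.g.: l2norm_chunk_sketch<4><<<1, 512>>>(d_x, n, d_out);
template <int ILP>
__global__ void l2norm_chunk_sketch(const float* x, int n, float* out) {
  __shared__ float s_vals[512];

  float vals[ILP];
  for (int i = 0; i < ILP; i++) vals[i] = 0.f;

  for (int i_start = 0; i_start < n; i_start += blockDim.x * ILP) {
#pragma unroll
    for (int ii = 0; ii < ILP; ii++) {
      int i = i_start + threadIdx.x + ii * blockDim.x;
      if (i < n) {
        float next = x[i];
        vals[ii] += next * next;  // sum of squares; sqrt happens in cleanup
      }
    }
  }

  // Collapse the ILP registers, then do a tree reduction across the block
  // (assumes blockDim.x is a power of two and <= 512).
  float val = 0.f;
  for (int i = 0; i < ILP; i++) val += vals[i];

  s_vals[threadIdx.x] = val;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) s_vals[threadIdx.x] += s_vals[threadIdx.x + stride];
    __syncthreads();
  }
  if (threadIdx.x == 0) *out = s_vals[0];
}
```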
-__global__ void cleanup(float* x, float* ret)
+__global__ void cleanup(
+  float* output,
+  float* output_per_tensor,
+  float* ret,
+  float* ret_per_tensor,
+  bool per_tensor,
+  int max_chunks_per_tensor)
 {
   __shared__ float vals[512];
 
+  if(blockIdx.x == 0)
+  {
     float val = 0;
     if(threadIdx.x < 320)
-    val = x[threadIdx.x];
+      val = output[threadIdx.x];
 
     float final = reduce_block_into_lanes(vals, val);
 
     if(threadIdx.x == 0)
       *ret = sqrt(final);
+  }
+
+  if(per_tensor)
+  {
+    float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;
+
+    float val = 0;
+    for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
+      val += output_this_tensor[i];
+
+    float final = reduce_block_into_lanes(vals, val);
+
+    if(threadIdx.x == 0)
+      ret_per_tensor[blockIdx.x] = sqrt(final);
+  }
 }
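The grid geometry follows from this: block 0 reduces the 320-entry global partial-sum buffer (only the first 320 threads load a value), and when `per_tensor` is set, each block `b` additionally reduces row `b` of `output_per_tensor`. A worked example with invented sizes:

```cpp
#include <cstdio>

// Worked example of the cleanup launch geometry (sizes are hypothetical).
int main() {
  const int ntensors = 4, max_chunks_per_tensor = 6;
  const int blocks = ntensors;  // per_tensor == true: one block per tensor
  printf("output_per_tensor entries: %d\n", ntensors * max_chunks_per_tensor);  // 24
  for (int b = 0; b < blocks; b++)
    printf("block %d sums entries [%d, %d) and writes sqrt into ret_per_tensor[%d]\n",
           b, b * max_chunks_per_tensor, (b + 1) * max_chunks_per_tensor, b);
  // Block 0 additionally reduces the 320-entry global buffer into ret.
}
```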
 
-at::Tensor multi_tensor_l2norm_cuda(
+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
   int chunk_size,
   at::Tensor noop_flag,
-  std::vector<std::vector<at::Tensor>> tensor_lists)
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  at::optional<bool> per_tensor_python)
 {
-  auto output = at::zeros({320}, tensor_lists[0][0].options().dtype(at::kFloat));
+  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
+
+  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
+  auto output = at::zeros({320}, float_options);
+
+  at::Tensor output_per_tensor;
+  at::Tensor ret_per_tensor;
+
+  int ntensors = tensor_lists[0].size();
+  int max_chunks_per_tensor = -1;
+
+  if(per_tensor)
+  {
+    for(int t = 0; t < ntensors; t++)
+    {
+      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
+      if(max_chunks_this_tensor > max_chunks_per_tensor)
+        max_chunks_per_tensor = max_chunks_this_tensor;
+    }
+    output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
+    ret_per_tensor = at::empty({ntensors}, float_options);
+  }
+  else
+  {
+    ret_per_tensor = at::empty({0}, float_options);
+  }
+
   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
     multi_tensor_apply<1>(
@@ -84,7 +154,10 @@ at::Tensor multi_tensor_l2norm_cuda(
       noop_flag,
       tensor_lists,
       L2NormFunctor<scalar_t_0>(),
-      output.data<float>());)
+      output.data<float>(),
+      per_tensor ? output_per_tensor.data<float>() : nullptr,
+      per_tensor,
+      max_chunks_per_tensor);)
 
   AT_CUDA_CHECK(cudaGetLastError());
@@ -95,6 +168,13 @@ at::Tensor multi_tensor_l2norm_cuda(
   // logic, but keeping it simple for now
   auto ret = at::empty({1}, output.options());
   auto stream = at::cuda::getCurrentCUDAStream();
-  cleanup<<<1, 512, 0, stream>>>(output.data<float>(), ret.data<float>());
-  return ret;
+  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
+    output.data<float>(),
+    per_tensor ? output_per_tensor.data<float>() : nullptr,
+    ret.data<float>(),
+    per_tensor ? ret_per_tensor.data<float>() : nullptr,
+    per_tensor,
+    max_chunks_per_tensor);
+
+  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
 }
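A hypothetical host-side caller, showing how the new tuple return is consumed; the declaration is repeated so the sketch stands alone, and the chunk size value is just an example, not taken from this diff:

```cpp
#include <tuple>
#include <vector>
#include <ATen/ATen.h>

// Declaration repeated from the frontend so this sketch is self-contained.
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

void example(at::Tensor noop_flag, std::vector<at::Tensor> params) {
  at::Tensor global_norm, per_tensor_norms;
  std::tie(global_norm, per_tensor_norms) = multi_tensor_l2norm_cuda(
      2048 * 32,  // example chunk size; an assumption for illustration
      noop_flag, {params}, /*per_tensor=*/true);
  // global_norm: shape [1], the L2 norm over all inputs combined.
  // per_tensor_norms: shape [params.size()], one norm per input tensor;
  // with per_tensor = false it comes back as an empty tensor.
}
```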
@@ -32,7 +32,7 @@ class TestMultiTensorL2Norm(unittest.TestCase):
         pass
 
     # The tensor creation here is written for convenience, not speed.
-    def l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type):
+    def l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type, per_tensor):
         self.overflow_buf.zero_()
         a = torch.cuda.FloatTensor(sizea).fill_(self.val)
         b = torch.cuda.FloatTensor(sizeb).fill_(self.val)
@@ -41,12 +41,18 @@ class TestMultiTensorL2Norm(unittest.TestCase):
         for i in range(repeat_tensors):
             in_list += [a.clone().to(in_type), b.clone().to(in_type)]
 
-        norm = applier(multi_tensor_l2norm, self.overflow_buf, [in_list])
+        if per_tensor:
+            norm, norm_per_tensor = applier(multi_tensor_l2norm, self.overflow_buf, [in_list], True)
+            normab = torch.cat((a.norm().view(1), b.norm().view(1)))
+            norm_per_tensor = norm_per_tensor.view(-1, 2)
+        else:
+            norm, _ = applier(multi_tensor_l2norm, self.overflow_buf, [in_list], True)
 
         reference = torch.cuda.FloatTensor((sizea + sizeb)*repeat_tensors).fill_(self.val).norm()
 
         self.assertTrue(torch.allclose(norm, reference))
+        if per_tensor:
+            self.assertTrue(torch.allclose(norm_per_tensor, normab))
         self.assertTrue(self.overflow_buf.item() == 0)
 
     @unittest.skipIf(disabled, "amp_C is unavailable")
@@ -72,7 +78,8 @@ class TestMultiTensorL2Norm(unittest.TestCase):
         for applier in appliers:
             for repeat in repeat_tensors:
                 for in_type in (torch.float32, torch.float16):
-                    self.l2norm(sizea, sizeb, applier, repeat, in_type)
+                    for per_tensor in (False, True):
+                        self.l2norm(sizea, sizeb, applier, repeat, in_type, per_tensor)