Commit 1a48b26b authored by Michael Carilli

Kernel + sizes stress test

parent e57f5d0e
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

template<typename x_t>
struct L2NormFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<1>& tl,
    float* output)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    // Locate the tensor and chunk this block is responsible for.
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    __shared__ float vals[512];

    // Non-divergent exit condition for __syncthreads, not necessary here.
    // Each thread accumulates a strided partial sum of squares in fp32.
    float val = 0;
    for(int i = threadIdx.x; i < n && i < chunk_size; i += blockDim.x)
    {
      float next = static_cast<float>(x[i]);
      val += next*next;
    }

    // Block-wide sum of the partial results (defaults: lanes=1, share_result=false).
    float final = reduce_block_into_lanes(vals, val);

    if(threadIdx.x == 0)
    {
      if(!isfinite(final))
        *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] += final;
    }
  }
};

at::Tensor multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists)
{
  // One fp32 accumulator per block; 320 presumably matches the per-launch block cap
  // in multi_tensor_apply.cuh.
  auto output = at::zeros({320}, tensor_lists[0][0].options().dtype(at::ScalarType::Float));

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
    multi_tensor_apply<1>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      L2NormFunctor<scalar_t_0>(),
      output.data<float>());)

  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves two more small kernel launches, but will be negligible end to end.
  // I could get rid of these by hacking the functor + multi tensor harness with persistence
  // logic, but keeping it simple for now.
  return output.sum().sqrt();
}
@@ -33,12 +33,12 @@ struct TypeShim
 }
-template<typename T, typename ReduceOp>
+template<typename T>
 __device__ __forceinline__ T reduce_block_into_lanes
   (T *x,
    T val,
-   int lanes,
-   bool share_result) // lanes is intended to be <= 32.
+   int lanes=1,
+   bool share_result=false) // lanes is intended to be <= 32.
 {
   int tid = threadIdx.x + threadIdx.y*blockDim.x;
   int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
...
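The hunk above drops the ReduceOp template parameter and gives lanes and share_result defaults, which is what lets the kernel above call reduce_block_into_lanes(vals, val) with no extra arguments. For readers who don't have type_shim.h open, a sum reduction with this shape typically looks like the sketch below. It is an illustrative reconstruction (renamed reduce_block_into_lanes_sketch), not the body from type_shim.h, and it assumes blockDim.x*blockDim.y is a multiple of 32 and lanes is a power of two <= 32.

// Sketch (not the actual type_shim.h implementation): block-wide sum reduction that
// leaves one partial result in each of the first `lanes` threads of the block.
// Assumes blockDim.x*blockDim.y is a multiple of 32 and lanes is a power of two <= 32.
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_sketch(T* x,  // shared scratch, one T per thread
                                                            T val,
                                                            int lanes = 1,
                                                            bool share_result = false)
{
  int tid = threadIdx.x + threadIdx.y*blockDim.x;
  int blockSize = blockDim.x*blockDim.y;

  // Stage each thread's value in shared memory and fold the block down to 64 partials.
  if(blockSize >= 64)
  {
    x[tid] = val;
    __syncthreads();
  }

  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
  {
    if(tid < i)
      x[tid] = x[tid] + x[tid + i];
    __syncthreads();
  }

  T final = val; // meaningful only for tid < lanes after the warp reduction below
  if(tid < 32)
  {
    if(blockSize >= 64)
      final = x[tid] + x[tid + 32];
    // The first warp finishes with shuffles, stopping once `lanes` results remain.
    for(int i = 16; i >= lanes; i >>= 1)
      final = final + __shfl_down_sync(0xffffffff, final, i);
  }

  // Optionally publish the results so every thread in the block can read x[0..lanes-1].
  if(share_result)
  {
    if(tid < lanes)
      x[tid] = final;
    __syncthreads();
  }

  return final;
}

With the defaults used by L2NormFunctor (lanes=1, share_result=false), only the thread with tid == 0 ends up holding the full block sum, which is why the kernel guards its output write with threadIdx.x == 0.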
import unittest
import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

try:
    import amp_C
    from amp_C import multi_tensor_l2norm
    from apex.multi_tensor_apply import MultiTensorApply
    disabled = False
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
    disabled = True


class TestMultiTensorL2Norm(unittest.TestCase):

    def setUp(self):
        common_init(self)
        self.val = 4.0
        self.overflow_buf = torch.cuda.IntTensor(1).zero_()

    def tearDown(self):
        pass

    # The tensor creation here is written for convenience, not speed.
    def l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type):
        self.overflow_buf.zero_()
        a = torch.cuda.FloatTensor(sizea).fill_(self.val)
        b = torch.cuda.FloatTensor(sizeb).fill_(self.val)

        in_list = []
        for i in range(repeat_tensors):
            in_list += [a.clone().to(in_type), b.clone().to(in_type)]

        norm = applier(multi_tensor_l2norm, self.overflow_buf, [in_list])

        reference = torch.cuda.FloatTensor((sizea + sizeb)*repeat_tensors).fill_(self.val).norm()

        self.assertTrue(torch.allclose(norm, reference))
        self.assertTrue(self.overflow_buf.item() == 0)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        input_size_pairs = (
            (7777*77, 555*555),
            (777, 555),
            (555, 2048*32+1),
            (2048*32+1, 555),
            (555, 2048*32),
            (2048*32, 555),
            (33333, 555),
            (555, 33333))
        appliers = (
            MultiTensorApply(2048*32),
            MultiTensorApply(333),
            MultiTensorApply(33333))
        repeat_tensors = (
            1,
            55)

        for sizea, sizeb in input_size_pairs:
            for applier in appliers:
                for repeat in repeat_tensors:
                    for in_type in (torch.float32, torch.float16):
                        self.l2norm(sizea, sizeb, applier, repeat, in_type)


if __name__ == '__main__':
    unittest.main()