Commit ac7dbf67 authored by Michael Carilli

Merge branch 'master' of https://github.com/NVIDIA/apex

parents 8437d295 5e552004
@@ -6,7 +6,16 @@ void multi_tensor_scale_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   float scale);
 
+void multi_tensor_axpby_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  float a,
+  float b);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors");
+  m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
+        "out = a*x + b*y for a list of contiguous tensors");
 }
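For orientation, here is a minimal sketch of how the new binding is driven from Python. It mirrors the applier usage in the test added further down; the tensor sizes and the 2048*32 chunk size are just illustrative values taken from that test, not anything this commit prescribes.

import torch
import amp_C
from apex.multi_tensor_apply import MultiTensorApply

# The applier chunks every tensor in the lists and launches one fused kernel.
applier = MultiTensorApply(2048*32)
# Set to 1 by the kernel if any input element is inf/nan.
overflow_buf = torch.cuda.IntTensor(1).zero_()

x = [torch.cuda.HalfTensor(4096).fill_(4.0)]
y = [torch.cuda.FloatTensor(4096).fill_(16.0)]
out = [torch.cuda.FloatTensor(4096).zero_()]

# Computes out = 2.0*x + 8.0*y elementwise for each (x, y, out) triple.
applier(amp_C.multi_tensor_axpby, overflow_buf, [x, y, out], 2.0, 8.0)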
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename x_t, typename y_t, typename out_t>
struct AxpbyFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<3>& tl,
    float a,
    float b)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    y_t* y = (y_t*)tl.addresses[1][tensor_loc];
    y += chunk_idx*chunk_size;

    out_t* out = (out_t*)tl.addresses[2][tensor_loc];
    out += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // Non-divergent exit condition for __syncthreads, not necessary here
    float xs[ILP];
    float ys[ILP];
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        xs[ii] = 0;
        ys[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          xs[ii] = static_cast<float>(x[i]);
          ys[ii] = static_cast<float>(y[i]);
        }
      }

      // see note in multi_tensor_scale_kernel.cu
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          if(isfinite(xs[ii]) && isfinite(ys[ii]))
            out[i] = static_cast<out_t>(a*xs[ii] + b*ys[ii]);
          else
          {
            out[i] = static_cast<out_t>(a*xs[ii] + b*ys[ii]);
            *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
          }
        }
      }
    }
  }
};
void multi_tensor_axpby_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float a,
  float b)
{
  using namespace at;

  // x, y, and out are dispatched independently, so any mix of float and half is supported.
  // If build times suffer, think about where to put this dispatch,
  // and what logic should be moved out of multi_tensor_apply.
  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
      DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
        multi_tensor_apply<3>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          AxpbyFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(),
          a,
          b); )))

  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());
}
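As a rough, unfused reference for what the kernel above computes (an illustration of the intended semantics, not code apex ships or calls): the arithmetic is done in fp32 regardless of the list dtypes, the result is written back even when it is inf/nan, and the noop/overflow flag is set whenever an input element is non-finite.

import torch

def axpby_reference(a, b, x_list, y_list, out_list, overflow_buf):
    # Per-tensor equivalent of AxpbyFunctor, without chunking or fusion.
    for x, y, out in zip(x_list, y_list, out_list):
        if not (torch.isfinite(x).all() and torch.isfinite(y).all()):
            overflow_buf.fill_(1)  # the kernel does racy blind writes; same end state
        # Compute in fp32, then cast to the output dtype (infs/nans propagate).
        out.copy_((a * x.float() + b * y.float()).to(out.dtype))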
@@ -12,3 +12,22 @@ struct TypeShim
   // Enable dispatch switch statements to take *this directly for post-3aeb78
   operator at::ScalarType(){ return payload.scalarType(); };
 };
+
+#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_##LEVEL = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
@@ -48,10 +48,12 @@ if "--cuda_ext" in sys.argv:
     ext_modules.append(
         CUDAExtension(name='amp_C',
                       sources=['csrc/amp_C_frontend.cpp',
-                               'csrc/multi_tensor_scale_kernel.cu'],
+                               'csrc/multi_tensor_scale_kernel.cu',
+                               'csrc/multi_tensor_axpby_kernel.cu'],
                       extra_compile_args={'cxx': ['-O3'],
                                           'nvcc':['-lineinfo',
                                                   '-O3',
+                                                  # '--resource-usage',
                                                   '--use_fast_math']}))
     ext_modules.append(
         CUDAExtension(name='fused_adam_cuda',
...
import unittest
import functools as ft
import itertools as it
from apex import amp
import torch
from torch import nn
import torch.nn.functional as F
from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

try:
    import amp_C
    from amp_C import multi_tensor_axpby
    from apex.multi_tensor_apply import MultiTensorApply
    disabled = False
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestMultiTensorAxpby. ImportError was ", err)
    disabled = True
class TestMultiTensorAxpby(unittest.TestCase):
    def setUp(self):
        common_init(self)

        self.a = 2.0
        self.b = 8.0
        self.xval = 4.0
        self.yval = 16.0
        self.overflow_buf = torch.cuda.IntTensor(1).zero_()
        self.ref = torch.cuda.FloatTensor([136.0])

    def tearDown(self):
        pass

    # The tensor creation here is written for convenience, not speed.
    def axpby(self, sizea, sizeb, applier, repeat_tensors,
              x_type, y_type, out_type, inplace=False):
        self.overflow_buf.zero_()
        t1 = torch.cuda.FloatTensor(sizea).fill_(1.0)
        t2 = torch.cuda.FloatTensor(sizeb).fill_(1.0)

        y_list = []
        for i in range(repeat_tensors):
            y_list += [t1.clone().to(y_type)*self.yval, t2.clone().to(y_type)*self.yval]

        x_list = [x.clone().to(x_type)*(self.xval/self.yval) for x in y_list]

        if inplace:
            out_list = y_list
        else:
            out_list = [out.clone().to(out_type)*3.0 for out in y_list]

        applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b)

        self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]),
                        msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
                                                          x_type, y_type, out_type, inplace))
        self.assertTrue(self.overflow_buf.item() == 0,
                        msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
                                                          x_type, y_type, out_type, inplace))
    # def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
    #     self.overflow_buf.zero_()
    #     a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
    #     b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)
    #     out_list = []
    #     for i in range(repeat_tensors):
    #         out_list += [a.clone().to(out_type), b.clone().to(out_type)]
    #     if inplace:
    #         in_list = out_list
    #     else:
    #         in_list = [out.clone().to(in_type) for out in out_list]
    #     applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
    #     self.overflow_buf.zero_()
    #     in_list[t][ind] = val
    #     applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
    #     self.assertTrue(self.overflow_buf.item())
    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        input_size_pairs = (
            (7777*77, 555*555),
            (777, 555),
            (555, 2048*32+1),
            (2048*32+1, 555),
            (555, 2048*32),
            (2048*32, 555),
            (33333, 555),
            (555, 33333))
        appliers = (
            MultiTensorApply(2048*32),
            MultiTensorApply(333),
            MultiTensorApply(33333))
        repeat_tensors = (
            1,
            55)

        for sizea, sizeb in input_size_pairs:
            for applier in appliers:
                for repeat in repeat_tensors:
                    for x_type in (torch.float32, torch.float16):
                        for y_type in (torch.float32, torch.float16):
                            for out_type in (torch.float32, torch.float16):
                                for inplace in (True, False):
                                    if inplace is True and (y_type is not out_type):
                                        continue
                                    else:
                                        self.axpby(sizea, sizeb, applier, repeat,
                                                   x_type, y_type, out_type, inplace=inplace)
                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                    #               0, 0, float('nan'), inplace=inplace)
                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                    #               2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                    #               2*(repeat//2), sizea//2, float('inf'), inplace=inplace)


if __name__ == '__main__':
    unittest.main()
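A quick sanity check on the constants in setUp: each x element ends up equal to xval (the y values of yval are rescaled by xval/yval), so the expected result per element is a*xval + b*yval.

a, b, xval, yval = 2.0, 8.0, 4.0, 16.0
assert a*xval + b*yval == 136.0  # matches self.ref above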
@@ -24,12 +24,11 @@ except ImportError as err:
 class TestMultiTensorScale(unittest.TestCase):
     def setUp(self):
+        common_init(self)
         self.scale = 4.0
         self.overflow_buf = torch.cuda.IntTensor(1).zero_()
         self.ref = torch.cuda.FloatTensor([1.0])
-        common_init(self)
 
     def tearDown(self):
         pass
...