Commit fb7d4e1d authored by Michael Carilli

Fleshed out CUDA version checking and compiling for multiple arches

parent d17a015f
@@ -49,6 +49,7 @@ class FP16_Module(nn.Module):
     def forward(self, *inputs, **kwargs):
         return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
+# TODO: Update overflow check + downscale to use Carl's fused kernel.
 class FP16_Optimizer(object):
     """
     :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
......
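For context, the docstring above describes :class:`FP16_Optimizer` as a wrapper around an existing PyTorch optimizer. A minimal usage sketch, assuming the constructor simply takes the wrapped optimizer (the exact signature at this commit may differ):

```python
import torch

# Hypothetical usage based on the docstring above: wrap an existing
# optimizer so fp16 model grads are managed through fp32 master weights.
model = FP16_Module(torch.nn.Linear(128, 128).cuda().half())
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = FP16_Optimizer(optimizer)  # assumed wrapping call
```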
@@ -61,6 +61,10 @@ void scale_check_overflow
 {
   AT_CHECK(grads.type().is_cuda(), "grads must be a CUDA tensor");
   AT_CHECK(overflow_buf.type().is_cuda(), "overflow_buf must be a CUDA tensor");
+  // Make sure we are downscaling the FP32 master grads
+  AT_CHECK
+    (grads.type().scalarType() == at::ScalarType::Float,
+     "grads supplied to scale_check_overflow should be fp32 (master grads).");
   scale_check_overflow_cuda(grads, scale, overflow_buf);
 }
......
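The guard added above locks `scale_check_overflow` to fp32 CUDA tensors, since downscaling must only ever touch the FP32 master grads. A minimal Python-side sketch of the same invariant, using a hypothetical `check_master_grads` helper that is not part of this commit:

```python
import torch

# Hypothetical mirror of the AT_CHECK guards above (not in this commit).
def check_master_grads(grads, overflow_buf):
    assert grads.is_cuda, "grads must be a CUDA tensor"
    assert overflow_buf.is_cuda, "overflow_buf must be a CUDA tensor"
    assert grads.dtype == torch.float32, \
        "grads supplied to scale_check_overflow should be fp32 (master grads)"
```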
@@ -13,6 +13,10 @@
 // It makes sense to lock the type to "float" here because the downscaling
 // should only be applied to the FP32 master gradients. Also, if "in" were
 // a different type, it would require divergent code for the vectorized load logic.
+// TODO:
+// Update overflow check to use reduction from kernel_utils.cuh with
+// ReduceOp from THCTensorMathReduce.cuh.
 __global__ void scale_reduce_overflow
 (float *in,
  size_t n,
......
import torch.cuda
import os
import re
import subprocess
from setuptools import setup, find_packages
from distutils.command.clean import clean
from torch.utils.cpp_extension import CppExtension, CUDAExtension
@@ -6,7 +9,53 @@ from torch.utils.cpp_extension import CUDA_HOME
 # TODO: multiple modules, so we don't have to route all interfaces through
 # the same interface.cpp file?
-if torch.cuda.is_available() and CUDA_HOME is not None:
+if not torch.cuda.is_available():
+    print("Warning: Torch did not find available GPUs on this system.\n",
+          "If your intention is to cross-compile, this is not an error.")
+def find(path, regex_func, collect=False):
+    # Return the first file under `path` whose name matches regex_func,
+    # or a deduplicated list of all matches when collect=True.
+    collection = []
+    for root, dirs, files in os.walk(path):
+        for file in files:
+            if regex_func(file):
+                if collect:
+                    collection.append(os.path.join(root, file))
+                else:
+                    return os.path.join(root, file)
+    return list(set(collection)) if collect else None
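A quick usage sketch for the `find` helper above; the paths and patterns are illustrative, not taken from this commit:

```python
import os
import re

CUDA_HOME = "/usr/local/cuda"  # illustrative install location

# First match only: locate the nvcc binary under the CUDA root.
nvcc_path = find(os.path.join(CUDA_HOME, "bin"), re.compile("nvcc$").search)

# collect=True gathers every match instead of returning the first one.
all_headers = find(CUDA_HOME, re.compile(r"\.h$").search, collect=True)
```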
+def get_cuda_version():
+    NVCC = find(CUDA_HOME + os.sep + "bin",
+                re.compile('nvcc$').search)
+    print("Found NVCC = ", NVCC)
+    # Parse the output of `nvcc --version` to get the CUDA major version
+    nvcc_output = subprocess.check_output([NVCC, '--version']).decode("utf-8")
+    CUDA_VERSION = re.compile(r', V[0-9]+\.[0-9]+\.[0-9]+').search(nvcc_output).group(0).split('V')[1]
+    print("Found CUDA_VERSION = ", CUDA_VERSION)
+    CUDA_MAJOR_VERSION = int(CUDA_VERSION.split('.')[0])
+    print("Found CUDA_MAJOR_VERSION = ", CUDA_MAJOR_VERSION)
+    if CUDA_MAJOR_VERSION < 8:
+        raise RuntimeError("Apex requires CUDA 8.0 or newer")
+    return CUDA_MAJOR_VERSION
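To see what the version parsing above extracts, here is a standalone sketch run against a representative final line of `nvcc --version` output (the sample string is typical of CUDA 9, not captured from this commit):

```python
import re

sample = "Cuda compilation tools, release 9.0, V9.0.176"

# Same pattern as above: match ", V9.0.176", keep the part after 'V'.
version = re.compile(r', V[0-9]+\.[0-9]+\.[0-9]+').search(sample).group(0).split('V')[1]
major = int(version.split('.')[0])
print(version, major)  # -> 9.0.176 9
```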
+if CUDA_HOME is not None:
+    print("Found CUDA_HOME = ", CUDA_HOME)
+    CUDA_MAJOR_VERSION = get_cuda_version()
+    gencodes = ['-gencode', 'arch=compute_52,code=sm_52',
+                '-gencode', 'arch=compute_60,code=sm_60',
+                '-gencode', 'arch=compute_61,code=sm_61',]
+    if CUDA_MAJOR_VERSION > 8:
+        gencodes += ['-gencode', 'arch=compute_70,code=sm_70',
+                     '-gencode', 'arch=compute_70,code=compute_70',]
     ext_modules = []
     extension = CUDAExtension(
         'apex._C', [
@@ -16,10 +65,10 @@ if torch.cuda.is_available() and CUDA_HOME is not None:
             'csrc/scale_cuda.cu',
         ],
         extra_compile_args={'cxx': ['-g'],
-                            'nvcc': ['-O2', '-arch=sm_70']}) # TODO: compile for all arches.
+                            'nvcc': ['-O3'] + gencodes})
     ext_modules.append(extension)
 else:
-    raise RuntimeError("Apex requires Cuda 9.0 or higher")
+    raise RuntimeError("Could not find CUDA install directory")
 setup(
     name='apex',
......
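A note on the `gencodes` list introduced above: each `arch=compute_XX,code=sm_XX` pair embeds native machine code (SASS) for one GPU architecture, while the extra `code=compute_70` entry also embeds PTX that future architectures can JIT-compile at load time. Since sm_70 (Volta) is only recognized from CUDA 9 onward, the list is gated on the detected major version. A hypothetical helper, not part of this commit, restating that policy:

```python
def gencode_flags(cuda_major_version):
    # SASS for Maxwell (sm_52) and Pascal (sm_60/61) unconditionally;
    # Volta SASS plus forward-compatible PTX on CUDA >= 9.
    flags = []
    for arch in (52, 60, 61):
        flags += ['-gencode', 'arch=compute_%d,code=sm_%d' % (arch, arch)]
    if cuda_major_version > 8:
        flags += ['-gencode', 'arch=compute_70,code=sm_70',
                  '-gencode', 'arch=compute_70,code=compute_70']
    return flags
```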
@@ -9,13 +9,13 @@ torch.cuda.manual_seed(2)
 # torch.cuda.manual_seed_all(2)
 torch.set_printoptions(precision=10)
-rows = 1    # 321
-cols = 4096 # 33
-fast = 4096 # 185
+rows = 321 # 1
+cols = 33  # 4096
+fast = 185 # 4096
 dims = rows, cols, fast
-dim = 2
-CUDA_HALF = False
+dim = 0
+CUDA_HALF = True
 RAND = True # If false, input gradients (the result of the backward pass)
             # should be analytically zero.
......