Commit e733e78c authored by Carl Case, committed by Michael Carilli

Initial support for automatic mixed precision

parent a3059288
@@ -114,13 +114,18 @@ for i, entry in enumerate(libaten_names):
 aten_h = find(torch_dir, re.compile("aten.h", re.IGNORECASE).search, False)
-include_dirs = [os.path.dirname(os.path.dirname(aten_h))]
+torch_inc = os.path.dirname(os.path.dirname(aten_h))
+include_dirs = [torch_inc]
 library_dirs = []
 for file in cuda_headers+headers:
     dir = os.path.dirname(file)
     if dir not in include_dirs:
         include_dirs.append(dir)
+
+# Object files that use the PyTorch cffi-extension interface
+# They need special handling during compilation
+cffi_objects = ['scale_kernel.o']
 assert libaten, "Could not find PyTorch's libATen."
 assert aten_h, "Could not find PyTorch's ATen header."
@@ -178,18 +183,29 @@ def CompileCudaFiles(NVCC, CUDA_VERSION):
     for dir in include_dirs:
         nvcc_cmd.append("-I"+dir)
+
+    # Hack: compiling the cffi kernel code needs the TH{C}
+    # subdirs of include on path as well
+    for suffix in ['TH', 'THC']:
+        nvcc_cmd.append('-I{}/{}'.format(torch_inc, suffix))
+
     for file in cuda_files:
         object_name = os.path.basename(
             os.path.splitext(file)[0]+".o"
         )
         object_file = os.path.join(buildir, object_name)
         object_files.append(object_file)
         file_opts = ['-c', file, '-o', object_file]
-        print(' '.join(nvcc_cmd+file_opts))
-        subprocess.check_call(nvcc_cmd+file_opts)
+        extra_args = []
+        if object_name in cffi_objects:
+            for module in ['TH', 'THC']:
+                extra_args.append('-I{}/{}'.format(torch_inc, module))
+        build_args = nvcc_cmd + extra_args + file_opts
+        print(' '.join(build_args))
+        subprocess.check_call(build_args)
     for object_file in object_files:
         extra_link_args.append(object_file)
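
For context, here is a small self-contained sketch (not part of the commit) of the per-object flag assembly the hunk above performs. The torch include path and the .cu file name are placeholders chosen for illustration.

# Illustrative only: mirrors the per-object nvcc flag assembly above,
# using a made-up torch include path.
import os

torch_inc = "/usr/local/lib/python3.6/dist-packages/torch/lib/include"  # assumed location
nvcc_cmd = ["nvcc", "-I" + torch_inc]
cffi_objects = ["scale_kernel.o"]

def build_args_for(cuda_file, buildir="build"):
    object_name = os.path.splitext(os.path.basename(cuda_file))[0] + ".o"
    object_file = os.path.join(buildir, object_name)
    file_opts = ["-c", cuda_file, "-o", object_file]
    extra_args = []
    if object_name in cffi_objects:
        # cffi kernels additionally get the TH and THC subdirectories on the include path
        extra_args = ["-I{}/{}".format(torch_inc, sub) for sub in ("TH", "THC")]
    return nvcc_cmd + extra_args + file_opts

print(" ".join(build_args_for("csrc/scale_kernel.cu")))

Running the sketch prints the command line that would be used for scale_kernel.o, with the extra -I.../TH and -I.../THC flags appended only for objects listed in cffi_objects.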
@@ -228,4 +244,10 @@ setup(
     ext_modules=[cuda_ext,],
     description='PyTorch Extensions written by NVIDIA',
     packages=find_packages(exclude=("build", "csrc", "include", "tests")),
+    # Require cffi
+    install_requires=["cffi>=1.0.0"],
+    setup_requires=["cffi>=1.0.0"],
+    cffi_modules=[os.path.join(os.path.dirname(__file__),
+                               'build_cffi.py:extension')],
 )
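
The setup() call now hands a builder script to setuptools through cffi_modules. That script (build_cffi.py) is not part of the hunks shown here; purely as a reference for how the cffi_modules hook works, an out-of-line cffi builder exposing an object named extension generally looks like the sketch below. The extension module name, C declaration, and C source are placeholders, not the actual apex sources.

# Generic shape of a cffi out-of-line builder (placeholder names, not apex's build_cffi.py).
import cffi

extension = cffi.FFI()

# Declare the C entry point the Python wrapper will call (hypothetical signature).
extension.cdef("""
void scale_tensor(float *data, float scale, int n);
""")

# Describe how to compile the wrapper module. A real CUDA-backed builder would
# typically link the precompiled kernel object (e.g. scale_kernel.o) via
# extra_objects instead of inlining a C implementation like this.
extension.set_source(
    "_scale_ext",  # placeholder extension module name
    """
    void scale_tensor(float *data, float scale, int n) {
        for (int i = 0; i < n; ++i) data[i] *= scale;
    }
    """,
)

if __name__ == "__main__":
    extension.compile(verbose=True)

With cffi_modules=['build_cffi.py:extension'], setuptools imports this script at build time, compiles the declared wrapper as an extension module, and installs it alongside the hand-built CUDA objects handled in CompileCudaFiles.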