Commit e733e78c authored by Carl Case's avatar Carl Case Committed by Michael Carilli
Browse files

Initial support for automatic mixed precision

parent a3059288
......@@ -114,13 +114,18 @@ for i, entry in enumerate(libaten_names):
aten_h = find(torch_dir, re.compile("aten.h", re.IGNORECASE).search, False)
include_dirs = [os.path.dirname(os.path.dirname(aten_h))]
torch_inc = os.path.dirname(os.path.dirname(aten_h))
include_dirs = [torch_inc]
library_dirs = []
for file in cuda_headers+headers:
dir = os.path.dirname(file)
if dir not in include_dirs:
include_dirs.append(dir)
# Object files that use the PyTorch cffi-extension interface
# They need special handling during compilation
cffi_objects = ['scale_kernel.o']
assert libaten, "Could not find PyTorch's libATen."
assert aten_h, "Could not find PyTorch's ATen header."
......@@ -178,18 +183,29 @@ def CompileCudaFiles(NVCC, CUDA_VERSION):
for dir in include_dirs:
nvcc_cmd.append("-I"+dir)
# Hack: compiling the cffi kernel code needs the TH{C}
# subdirs of include on path as well
for suffix in ['TH', 'THC']:
nvcc_cmd.append('-I{}/{}'.format(torch_inc, suffix))
for file in cuda_files:
object_name = os.path.basename(
os.path.splitext(file)[0]+".o"
)
object_file = os.path.join(buildir, object_name)
object_files.append(object_file)
file_opts = ['-c', file, '-o', object_file]
print(' '.join(nvcc_cmd+file_opts))
subprocess.check_call(nvcc_cmd+file_opts)
extra_args = []
if object_name in cffi_objects:
for module in ['TH', 'THC']:
extra_args.append('-I{}/{}'.format(torch_inc, module))
build_args = nvcc_cmd + extra_args + file_opts
print(' '.join(build_args))
subprocess.check_call(build_args)
for object_file in object_files:
extra_link_args.append(object_file)
......@@ -228,4 +244,10 @@ setup(
ext_modules=[cuda_ext,],
description='PyTorch Extensions written by NVIDIA',
packages=find_packages(exclude=("build", "csrc", "include", "tests")),
# Require cffi
install_requires=["cffi>=1.0.0"],
setup_requires=["cffi>=1.0.0"],
cffi_modules=[os.path.join(os.path.dirname(__file__),
'build_cffi.py:extension')],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment