"vscode:/vscode.git/clone" did not exist on "e21946e03b8a51aad8539e358b8de83b80610430"
Commit e733e78c authored by Carl Case, committed by Michael Carilli
Browse files

Initial support for automatic mixed precision

parent a3059288
...@@ -114,13 +114,18 @@ for i, entry in enumerate(libaten_names): ...@@ -114,13 +114,18 @@ for i, entry in enumerate(libaten_names):
aten_h = find(torch_dir, re.compile("aten.h", re.IGNORECASE).search, False) aten_h = find(torch_dir, re.compile("aten.h", re.IGNORECASE).search, False)
include_dirs = [os.path.dirname(os.path.dirname(aten_h))] torch_inc = os.path.dirname(os.path.dirname(aten_h))
include_dirs = [torch_inc]
library_dirs = [] library_dirs = []
for file in cuda_headers+headers: for file in cuda_headers+headers:
dir = os.path.dirname(file) dir = os.path.dirname(file)
if dir not in include_dirs: if dir not in include_dirs:
include_dirs.append(dir) include_dirs.append(dir)
# Object files that use the PyTorch cffi-extension interface
# They need special handling during compilation
cffi_objects = ['scale_kernel.o']
assert libaten, "Could not find PyTorch's libATen." assert libaten, "Could not find PyTorch's libATen."
assert aten_h, "Could not find PyTorch's ATen header." assert aten_h, "Could not find PyTorch's ATen header."
...@@ -178,6 +183,11 @@ def CompileCudaFiles(NVCC, CUDA_VERSION): ...@@ -178,6 +183,11 @@ def CompileCudaFiles(NVCC, CUDA_VERSION):
for dir in include_dirs: for dir in include_dirs:
nvcc_cmd.append("-I"+dir) nvcc_cmd.append("-I"+dir)
# Hack: compiling the cffi kernel code needs the TH{C}
# subdirs of include on path as well
for suffix in ['TH', 'THC']:
nvcc_cmd.append('-I{}/{}'.format(torch_inc, suffix))
for file in cuda_files: for file in cuda_files:
object_name = os.path.basename( object_name = os.path.basename(
os.path.splitext(file)[0]+".o" os.path.splitext(file)[0]+".o"
...@@ -188,8 +198,14 @@ def CompileCudaFiles(NVCC, CUDA_VERSION): ...@@ -188,8 +198,14 @@ def CompileCudaFiles(NVCC, CUDA_VERSION):
file_opts = ['-c', file, '-o', object_file] file_opts = ['-c', file, '-o', object_file]
print(' '.join(nvcc_cmd+file_opts)) extra_args = []
subprocess.check_call(nvcc_cmd+file_opts) if object_name in cffi_objects:
for module in ['TH', 'THC']:
extra_args.append('-I{}/{}'.format(torch_inc, module))
build_args = nvcc_cmd + extra_args + file_opts
print(' '.join(build_args))
subprocess.check_call(build_args)
for object_file in object_files: for object_file in object_files:
extra_link_args.append(object_file) extra_link_args.append(object_file)
...@@ -228,4 +244,10 @@ setup( ...@@ -228,4 +244,10 @@ setup(
ext_modules=[cuda_ext,], ext_modules=[cuda_ext,],
description='PyTorch Extensions written by NVIDIA', description='PyTorch Extensions written by NVIDIA',
packages=find_packages(exclude=("build", "csrc", "include", "tests")), packages=find_packages(exclude=("build", "csrc", "include", "tests")),
# Require cffi
install_requires=["cffi>=1.0.0"],
setup_requires=["cffi>=1.0.0"],
cffi_modules=[os.path.join(os.path.dirname(__file__),
'build_cffi.py:extension')],
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment