Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
e733e78c
Commit
e733e78c
authored
May 16, 2018
by
Carl Case
Committed by
Michael Carilli
May 17, 2018
Browse files
Initial support for automatic mixed precision
parent
a3059288
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
5 deletions
+27
-5
setup.py
setup.py
+27
-5
No files found.
setup.py
View file @
e733e78c
...
...
@@ -114,13 +114,18 @@ for i, entry in enumerate(libaten_names):
aten_h
=
find
(
torch_dir
,
re
.
compile
(
"aten.h"
,
re
.
IGNORECASE
).
search
,
False
)
include_dirs
=
[
os
.
path
.
dirname
(
os
.
path
.
dirname
(
aten_h
))]
torch_inc
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
aten_h
))
include_dirs
=
[
torch_inc
]
library_dirs
=
[]
for
file
in
cuda_headers
+
headers
:
dir
=
os
.
path
.
dirname
(
file
)
if
dir
not
in
include_dirs
:
include_dirs
.
append
(
dir
)
# Object files that use the PyTorch cffi-extension interface
# They need special handling during compilation
cffi_objects
=
[
'scale_kernel.o'
]
assert
libaten
,
"Could not find PyTorch's libATen."
assert
aten_h
,
"Could not find PyTorch's ATen header."
...
...
@@ -178,18 +183,29 @@ def CompileCudaFiles(NVCC, CUDA_VERSION):
for
dir
in
include_dirs
:
nvcc_cmd
.
append
(
"-I"
+
dir
)
# Hack: compiling the cffi kernel code needs the TH{C}
# subdirs of include on path as well
for
suffix
in
[
'TH'
,
'THC'
]:
nvcc_cmd
.
append
(
'-I{}/{}'
.
format
(
torch_inc
,
suffix
))
for
file
in
cuda_files
:
object_name
=
os
.
path
.
basename
(
os
.
path
.
splitext
(
file
)[
0
]
+
".o"
)
object_file
=
os
.
path
.
join
(
buildir
,
object_name
)
object_files
.
append
(
object_file
)
file_opts
=
[
'-c'
,
file
,
'-o'
,
object_file
]
print
(
' '
.
join
(
nvcc_cmd
+
file_opts
))
subprocess
.
check_call
(
nvcc_cmd
+
file_opts
)
extra_args
=
[]
if
object_name
in
cffi_objects
:
for
module
in
[
'TH'
,
'THC'
]:
extra_args
.
append
(
'-I{}/{}'
.
format
(
torch_inc
,
module
))
build_args
=
nvcc_cmd
+
extra_args
+
file_opts
print
(
' '
.
join
(
build_args
))
subprocess
.
check_call
(
build_args
)
for
object_file
in
object_files
:
extra_link_args
.
append
(
object_file
)
...
...
@@ -228,4 +244,10 @@ setup(
ext_modules
=
[
cuda_ext
,],
description
=
'PyTorch Extensions written by NVIDIA'
,
packages
=
find_packages
(
exclude
=
(
"build"
,
"csrc"
,
"include"
,
"tests"
)),
# Require cffi
install_requires
=
[
"cffi>=1.0.0"
],
setup_requires
=
[
"cffi>=1.0.0"
],
cffi_modules
=
[
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'build_cffi.py:extension'
)],
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment