Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
50338df6
Unverified
Commit
50338df6
authored
Feb 14, 2020
by
Deyu Fu
Committed by
GitHub
Feb 14, 2020
Browse files
change include_dirs to abs path (#719)
parent
5b71d369
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
7 deletions
+10
-7
setup.py
setup.py
+10
-7
No files found.
setup.py
View file @
50338df6
...
@@ -6,6 +6,9 @@ import sys
...
@@ -6,6 +6,9 @@ import sys
import
warnings
import
warnings
import
os
import
os
# ninja build does not work unless include_dirs are abs path
this_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
if
not
torch
.
cuda
.
is_available
():
if
not
torch
.
cuda
.
is_available
():
# https://github.com/NVIDIA/apex/issues/486
# https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
...
@@ -148,7 +151,7 @@ if "--bnp" in sys.argv:
...
@@ -148,7 +151,7 @@ if "--bnp" in sys.argv:
'apex/contrib/csrc/groupbn/ipc.cu'
,
'apex/contrib/csrc/groupbn/ipc.cu'
,
'apex/contrib/csrc/groupbn/interface.cpp'
,
'apex/contrib/csrc/groupbn/interface.cpp'
,
'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'
],
'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'
],
include_dirs
=
[
'csrc'
],
include_dirs
=
[
os
.
path
.
join
(
this_dir
,
'csrc'
)
],
extra_compile_args
=
{
'cxx'
:
[]
+
version_dependent_macros
,
extra_compile_args
=
{
'cxx'
:
[]
+
version_dependent_macros
,
'nvcc'
:[
'-DCUDA_HAS_FP16=1'
,
'nvcc'
:[
'-DCUDA_HAS_FP16=1'
,
'-D__CUDA_NO_HALF_OPERATORS__'
,
'-D__CUDA_NO_HALF_OPERATORS__'
,
...
@@ -169,7 +172,7 @@ if "--xentropy" in sys.argv:
...
@@ -169,7 +172,7 @@ if "--xentropy" in sys.argv:
CUDAExtension
(
name
=
'xentropy_cuda'
,
CUDAExtension
(
name
=
'xentropy_cuda'
,
sources
=
[
'apex/contrib/csrc/xentropy/interface.cpp'
,
sources
=
[
'apex/contrib/csrc/xentropy/interface.cpp'
,
'apex/contrib/csrc/xentropy/xentropy_kernel.cu'
],
'apex/contrib/csrc/xentropy/xentropy_kernel.cu'
],
include_dirs
=
[
'csrc'
],
include_dirs
=
[
os
.
path
.
join
(
this_dir
,
'csrc'
)
],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
]
+
version_dependent_macros
,
extra_compile_args
=
{
'cxx'
:
[
'-O3'
]
+
version_dependent_macros
,
'nvcc'
:[
'-O3'
]
+
version_dependent_macros
}))
'nvcc'
:[
'-O3'
]
+
version_dependent_macros
}))
...
@@ -187,7 +190,7 @@ if "--deprecated_fused_adam" in sys.argv:
...
@@ -187,7 +190,7 @@ if "--deprecated_fused_adam" in sys.argv:
CUDAExtension
(
name
=
'fused_adam_cuda'
,
CUDAExtension
(
name
=
'fused_adam_cuda'
,
sources
=
[
'apex/contrib/csrc/optimizers/fused_adam_cuda.cpp'
,
sources
=
[
'apex/contrib/csrc/optimizers/fused_adam_cuda.cpp'
,
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'
],
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'
],
include_dirs
=
[
'csrc'
],
include_dirs
=
[
os
.
path
.
join
(
this_dir
,
'csrc'
)
],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
,]
+
version_dependent_macros
,
extra_compile_args
=
{
'cxx'
:
[
'-O3'
,]
+
version_dependent_macros
,
'nvcc'
:[
'-O3'
,
'nvcc'
:[
'-O3'
,
'--use_fast_math'
]
+
version_dependent_macros
}))
'--use_fast_math'
]
+
version_dependent_macros
}))
...
@@ -206,7 +209,7 @@ if "--fast_multihead_attn" in sys.argv:
...
@@ -206,7 +209,7 @@ if "--fast_multihead_attn" in sys.argv:
subprocess
.
run
([
"git"
,
"submodule"
,
"update"
,
"--init"
,
"apex/contrib/csrc/multihead_attn/cutlass"
])
subprocess
.
run
([
"git"
,
"submodule"
,
"update"
,
"--init"
,
"apex/contrib/csrc/multihead_attn/cutlass"
])
ext_modules
.
append
(
ext_modules
.
append
(
CUDAExtension
(
name
=
'fast_self_multihead_attn'
,
CUDAExtension
(
name
=
'fast_self_multihead_attn'
,
sources
=
[
'apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp'
,
sources
=
[
'apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp'
,
'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'
],
'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'
],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
,]
+
version_dependent_macros
,
extra_compile_args
=
{
'cxx'
:
[
'-O3'
,]
+
version_dependent_macros
,
'nvcc'
:[
'-O3'
,
'nvcc'
:[
'-O3'
,
...
@@ -219,7 +222,7 @@ if "--fast_multihead_attn" in sys.argv:
...
@@ -219,7 +222,7 @@ if "--fast_multihead_attn" in sys.argv:
'--use_fast_math'
]
+
version_dependent_macros
}))
'--use_fast_math'
]
+
version_dependent_macros
}))
ext_modules
.
append
(
ext_modules
.
append
(
CUDAExtension
(
name
=
'fast_self_multihead_attn_norm_add'
,
CUDAExtension
(
name
=
'fast_self_multihead_attn_norm_add'
,
sources
=
[
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp'
,
sources
=
[
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp'
,
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'
],
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'
],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
,]
+
version_dependent_macros
,
extra_compile_args
=
{
'cxx'
:
[
'-O3'
,]
+
version_dependent_macros
,
'nvcc'
:[
'-O3'
,
'nvcc'
:[
'-O3'
,
...
@@ -232,7 +235,7 @@ if "--fast_multihead_attn" in sys.argv:
...
@@ -232,7 +235,7 @@ if "--fast_multihead_attn" in sys.argv:
'--use_fast_math'
]
+
version_dependent_macros
}))
'--use_fast_math'
]
+
version_dependent_macros
}))
ext_modules
.
append
(
ext_modules
.
append
(
CUDAExtension
(
name
=
'fast_encdec_multihead_attn'
,
CUDAExtension
(
name
=
'fast_encdec_multihead_attn'
,
sources
=
[
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp'
,
sources
=
[
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp'
,
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'
],
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'
],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
,]
+
version_dependent_macros
,
extra_compile_args
=
{
'cxx'
:
[
'-O3'
,]
+
version_dependent_macros
,
'nvcc'
:[
'-O3'
,
'nvcc'
:[
'-O3'
,
...
@@ -245,7 +248,7 @@ if "--fast_multihead_attn" in sys.argv:
...
@@ -245,7 +248,7 @@ if "--fast_multihead_attn" in sys.argv:
'--use_fast_math'
]
+
version_dependent_macros
}))
'--use_fast_math'
]
+
version_dependent_macros
}))
ext_modules
.
append
(
ext_modules
.
append
(
CUDAExtension
(
name
=
'fast_encdec_multihead_attn_norm_add'
,
CUDAExtension
(
name
=
'fast_encdec_multihead_attn_norm_add'
,
sources
=
[
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp'
,
sources
=
[
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp'
,
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'
],
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'
],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
,]
+
version_dependent_macros
,
extra_compile_args
=
{
'cxx'
:
[
'-O3'
,]
+
version_dependent_macros
,
'nvcc'
:[
'-O3'
,
'nvcc'
:[
'-O3'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment