Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
b10621d1
Commit
b10621d1
authored
Nov 08, 2022
by
flyingdown
Browse files
修改setup.py,修复编译错误,适配dtk-22.10
parent
86dfa18d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
5 deletions
+15
-5
csrc/multi_tensor_l2norm_kernel_mp.cu
csrc/multi_tensor_l2norm_kernel_mp.cu
+5
-1
setup.py
setup.py
+10
-4
No files found.
csrc/multi_tensor_l2norm_kernel_mp.cu
View file @
b10621d1
...
...
@@ -109,7 +109,11 @@ struct L2NormFunctor
}
};
__global__
void
cleanup
(
__global__
void
#ifdef __HIP_PLATFORM_HCC__
__launch_bounds__
(
1024
)
#endif
cleanup
(
float
*
output
,
float
*
output_per_tensor
,
float
*
ret
,
...
...
setup.py
View file @
b10621d1
import
torch
from
torch.utils.cpp_extension
import
BuildExtension
,
CppExtension
,
CUDAExtension
,
CUDA_HOME
from
torch.utils.cpp_extension
import
BuildExtension
,
CppExtension
,
CUDAExtension
,
CUDA_HOME
,
ROCM_HOME
from
setuptools
import
setup
,
find_packages
import
subprocess
...
...
@@ -275,6 +275,7 @@ if "--cuda_ext" in sys.argv:
CUDAExtension
(
name
=
'fused_dense_cuda'
,
sources
=
[
'csrc/fused_dense.cpp'
,
'csrc/fused_dense_cuda.cu'
],
include_dirs
=
[
os
.
path
.
join
(
this_dir
,
'csrc'
)],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
]
+
version_dependent_macros
,
'nvcc'
:[
'-O3'
]
+
version_dependent_macros
}))
nvcc_args_transformer
=
[
'-O3'
,
...
...
@@ -522,8 +523,8 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
'--use_fast_math'
]
+
version_dependent_macros
+
generator_flag
+
cc_flag
hipcc_args_mha
=
[
'-O3'
,
'-Iapex/contrib/csrc/multihead_attn/cutlass'
,
'-I
/opt/rocm/
include/hiprand'
,
'-I
/opt/rocm/
include/rocrand'
,
'-I
'
+
os
.
path
.
join
(
ROCM_HOME
,
'
include/hiprand'
)
,
'-I
'
+
os
.
path
.
join
(
ROCM_HOME
,
'
include/rocrand'
)
,
'-U__HIP_NO_HALF_OPERATORS__'
,
'-U__HIP_NO_HALF_CONVERSIONS__'
]
+
version_dependent_macros
+
generator_flag
if
found_Backward_Pass_Guard
:
...
...
@@ -559,6 +560,9 @@ if "--transducer" in sys.argv or "--cuda_ext" in sys.argv:
if
not
IS_ROCM_PYTORCH
:
raise_if_cuda_home_none
(
"--transducer"
)
hipcc_args_mha
=
[
'-O3'
,
'-I'
+
os
.
path
.
join
(
ROCM_HOME
,
'include/hiprand'
),
'-I'
+
os
.
path
.
join
(
ROCM_HOME
,
'include/rocrand'
),]
+
version_dependent_macros
+
generator_flag
ext_modules
.
append
(
CUDAExtension
(
name
=
"transducer_joint_cuda"
,
...
...
@@ -569,7 +573,7 @@ if "--transducer" in sys.argv or "--cuda_ext" in sys.argv:
extra_compile_args
=
{
"cxx"
:
[
"-O3"
]
+
version_dependent_macros
+
generator_flag
,
"nvcc"
:
append_nvcc_threads
([
"-O3"
]
+
version_dependent_macros
+
generator_flag
)
if
not
IS_ROCM_PYTORCH
else
[
"-O3"
]
+
version_dependent_macros
+
generator_flag
,
else
hipcc_args_mha
,
},
include_dirs
=
[
os
.
path
.
join
(
this_dir
,
"csrc"
),
os
.
path
.
join
(
this_dir
,
"apex/contrib/csrc/multihead_attn"
)],
)
...
...
@@ -619,6 +623,7 @@ if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
"apex/contrib/csrc/peer_memory/peer_memory_cuda.cu"
,
"apex/contrib/csrc/peer_memory/peer_memory.cpp"
,
],
include_dirs
=
[
os
.
path
.
join
(
this_dir
,
"apex/contrib/csrc/nccl_p2p"
)],
extra_compile_args
=
{
"cxx"
:
[
"-O3"
]
+
version_dependent_macros
+
generator_flag
},
)
)
...
...
@@ -637,6 +642,7 @@ if "--nccl_p2p" in sys.argv or "--cuda_ext" in sys.argv:
"apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu"
,
"apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp"
,
],
include_dirs
=
[
os
.
path
.
join
(
this_dir
,
"apex/contrib/csrc/nccl_p2p"
)],
extra_compile_args
=
{
"cxx"
:
[
"-O3"
]
+
version_dependent_macros
+
generator_flag
},
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment