Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
d89f5e66
Unverified
Commit
d89f5e66
authored
Apr 18, 2022
by
Masaki Kozuki
Committed by
GitHub
Apr 18, 2022
Browse files
[submodule update] Bump cudnn-frontend to v0.6.1 (#1353)
* bump version * add guard * fix the cond
parent
727a6452
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
9 deletions
+17
-9
apex/contrib/csrc/cudnn-frontend
apex/contrib/csrc/cudnn-frontend
+1
-1
setup.py
setup.py
+16
-8
No files found.
cudnn-frontend
@
fa611998
Compare
7b83dba8
...
fa611998
Subproject commit
7b83dba83fa31381aeca508d89aab94f4639ac6d
Subproject commit
fa611998a360cbabaa2dcc7c9859748144114fc0
setup.py
View file @
d89f5e66
...
@@ -58,6 +58,13 @@ def append_nvcc_threads(nvcc_extra_args):
...
@@ -58,6 +58,13 @@ def append_nvcc_threads(nvcc_extra_args):
return
nvcc_extra_args
return
nvcc_extra_args
def
check_cudnn_version_and_warn
(
global_option
:
str
,
required_cudnn_version
:
int
)
->
bool
:
green
=
torch
.
backends
.
cudnn
.
is_available
()
and
torch
.
backends
.
cudnn
.
version
()
>=
required_cudnn_version
if
not
green
:
warnings
.
warn
(
f
"Skip `
{
global_option
}
` as it requires cuDNN
{
required_cudnn_version
}
or later"
)
return
green
if
not
torch
.
cuda
.
is_available
():
if
not
torch
.
cuda
.
is_available
():
# https://github.com/NVIDIA/apex/issues/486
# https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
...
@@ -649,15 +656,16 @@ if "--fast_bottleneck" in sys.argv:
...
@@ -649,15 +656,16 @@ if "--fast_bottleneck" in sys.argv:
if
"--fused_conv_bias_relu"
in
sys
.
argv
:
if
"--fused_conv_bias_relu"
in
sys
.
argv
:
sys
.
argv
.
remove
(
"--fused_conv_bias_relu"
)
sys
.
argv
.
remove
(
"--fused_conv_bias_relu"
)
raise_if_cuda_home_none
(
"--fused_conv_bias_relu"
)
raise_if_cuda_home_none
(
"--fused_conv_bias_relu"
)
subprocess
.
run
([
"git"
,
"submodule"
,
"update"
,
"--init"
,
"apex/contrib/csrc/cudnn-frontend/"
])
if
check_cudnn_version_and_warn
(
"--fused_conv_bias_relu"
,
8400
):
ext_modules
.
append
(
subprocess
.
run
([
"git"
,
"submodule"
,
"update"
,
"--init"
,
"apex/contrib/csrc/cudnn-frontend/"
])
CUDAExtension
(
ext_modules
.
append
(
name
=
"fused_conv_bias_relu"
,
CUDAExtension
(
sources
=
[
"apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"
],
name
=
"fused_conv_bias_relu"
,
include_dirs
=
[
os
.
path
.
join
(
this_dir
,
"apex/contrib/csrc/cudnn-frontend/include"
)],
sources
=
[
"apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"
],
extra_compile_args
=
{
"cxx"
:
[
"-O3"
]
+
version_dependent_macros
+
generator_flag
},
include_dirs
=
[
os
.
path
.
join
(
this_dir
,
"apex/contrib/csrc/cudnn-frontend/include"
)],
extra_compile_args
=
{
"cxx"
:
[
"-O3"
]
+
version_dependent_macros
+
generator_flag
},
)
)
)
)
setup
(
setup
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment