Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
fb7d8f3c
Commit
fb7d8f3c
authored
Mar 11, 2024
by
xiabo
Browse files
Adapt to torch2.1
parent
e2f0eed9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
5 deletions
+14
-5
mmcv/cnn/bricks/generalized_attention.py
mmcv/cnn/bricks/generalized_attention.py
+1
-1
mmcv/version.py
mmcv/version.py
+1
-1
setup.py
setup.py
+12
-3
No files found.
mmcv/cnn/bricks/generalized_attention.py
View file @
fb7d8f3c
...
@@ -371,7 +371,7 @@ class GeneralizedAttention(nn.Module):
...
@@ -371,7 +371,7 @@ class GeneralizedAttention(nn.Module):
contiguous
().
\
contiguous
().
\
view
(
1
,
1
,
h
*
w
,
h_kv
*
w_kv
)
view
(
1
,
1
,
h
*
w
,
h_kv
*
w_kv
)
energy
=
energy
.
masked_fill_
(
cur_local_constraint_map
,
energy
=
energy
.
masked_fill_
(
cur_local_constraint_map
.
bool
()
,
float
(
'-inf'
))
float
(
'-inf'
))
attention
=
F
.
softmax
(
energy
,
3
)
attention
=
F
.
softmax
(
energy
,
3
)
...
...
mmcv/version.py
View file @
fb7d8f3c
...
@@ -32,4 +32,4 @@ def parse_version_info(version_str: str, length: int = 4) -> tuple:
...
@@ -32,4 +32,4 @@ def parse_version_info(version_str: str, length: int = 4) -> tuple:
version_info
=
tuple
(
int
(
x
)
for
x
in
__version__
.
split
(
'.'
)[:
3
])
version_info
=
tuple
(
int
(
x
)
for
x
in
__version__
.
split
(
'.'
)[:
3
])
__all__
=
[
'__version__'
,
'version_info'
,
'parse_version_info'
]
__all__
=
[
'__version__'
,
'__dcu_version__'
,
'version_info'
,
'parse_version_info'
]
setup.py
View file @
fb7d8f3c
...
@@ -257,13 +257,19 @@ def get_extensions():
...
@@ -257,13 +257,19 @@ def get_extensions():
extra_compile_args
=
{
'cxx'
:
[]}
extra_compile_args
=
{
'cxx'
:
[]}
if
platform
.
system
()
!=
'Windows'
:
if
platform
.
system
()
!=
'Windows'
:
extra_compile_args
[
'cxx'
]
=
[
'-std=c++14'
]
if
parse_version
(
torch
.
__version__
)
<=
parse_version
(
'1.12.1'
):
extra_compile_args
[
'cxx'
]
=
[
'-std=c++14'
]
else
:
extra_compile_args
[
'cxx'
]
=
[
'-std=c++17'
]
else
:
else
:
# TODO: In Windows, C++17 is chosen to compile extensions in
# TODO: In Windows, C++17 is chosen to compile extensions in
# PyTorch2.0 , but a compile error will be reported.
# PyTorch2.0 , but a compile error will be reported.
# As a temporary solution, force the use of C++14.
# As a temporary solution, force the use of C++14.
if
parse_version
(
torch
.
__version__
)
>
=
parse_version
(
'
2.0.0
'
):
if
parse_version
(
torch
.
__version__
)
<
=
parse_version
(
'
1.12.1
'
):
extra_compile_args
[
'cxx'
]
=
[
'/std:c++14'
]
extra_compile_args
[
'cxx'
]
=
[
'/std:c++14'
]
else
:
extra_compile_args
[
'cxx'
]
=
[
'/std:c++17'
]
include_dirs
=
[]
include_dirs
=
[]
library_dirs
=
[]
library_dirs
=
[]
...
@@ -477,7 +483,10 @@ def get_extensions():
...
@@ -477,7 +483,10 @@ def get_extensions():
# to compile those cpp files, so there is no need to add the
# to compile those cpp files, so there is no need to add the
# argument
# argument
if
'nvcc'
in
extra_compile_args
and
platform
.
system
()
!=
'Windows'
:
if
'nvcc'
in
extra_compile_args
and
platform
.
system
()
!=
'Windows'
:
extra_compile_args
[
'nvcc'
]
+=
[
'-std=c++14'
]
if
parse_version
(
torch
.
__version__
)
<=
parse_version
(
'1.12.1'
):
extra_compile_args
[
'nvcc'
]
+=
[
'-std=c++14'
]
else
:
extra_compile_args
[
'nvcc'
]
+=
[
'-std=c++17'
]
ext_ops
=
extension
(
ext_ops
=
extension
(
name
=
ext_name
,
name
=
ext_name
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment