Commit fb7d8f3c authored by xiabo's avatar xiabo
Browse files

Adapt to torch2.1

parent e2f0eed9
......@@ -371,7 +371,7 @@ class GeneralizedAttention(nn.Module):
contiguous().\
view(1, 1, h*w, h_kv*w_kv)
energy = energy.masked_fill_(cur_local_constraint_map,
energy = energy.masked_fill_(cur_local_constraint_map.bool(),
float('-inf'))
attention = F.softmax(energy, 3)
......
......@@ -32,4 +32,4 @@ def parse_version_info(version_str: str, length: int = 4) -> tuple:
version_info = tuple(int(x) for x in __version__.split('.')[:3])
__all__ = ['__version__', 'version_info', 'parse_version_info']
__all__ = ['__version__', '__dcu_version__', 'version_info', 'parse_version_info']
......@@ -257,13 +257,19 @@ def get_extensions():
extra_compile_args = {'cxx': []}
if platform.system() != 'Windows':
if parse_version(torch.__version__) <= parse_version('1.12.1'):
extra_compile_args['cxx'] = ['-std=c++14']
else:
extra_compile_args['cxx'] = ['-std=c++17']
else:
# TODO: In Windows, C++17 is chosen to compile extensions in
# PyTorch2.0 , but a compile error will be reported.
# As a temporary solution, force the use of C++14.
if parse_version(torch.__version__) >= parse_version('2.0.0'):
if parse_version(torch.__version__) <= parse_version('1.12.1'):
extra_compile_args['cxx'] = ['/std:c++14']
else:
extra_compile_args['cxx'] = ['/std:c++17']
include_dirs = []
library_dirs = []
......@@ -477,7 +483,10 @@ def get_extensions():
# to compile those cpp files, so there is no need to add the
# argument
if 'nvcc' in extra_compile_args and platform.system() != 'Windows':
if parse_version(torch.__version__) <= parse_version('1.12.1'):
extra_compile_args['nvcc'] += ['-std=c++14']
else:
extra_compile_args['nvcc'] += ['-std=c++17']
ext_ops = extension(
name=ext_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment