Commit fb7d8f3c authored by xiabo's avatar xiabo
Browse files

Adapt to torch2.1

parent e2f0eed9
...@@ -371,7 +371,7 @@ class GeneralizedAttention(nn.Module): ...@@ -371,7 +371,7 @@ class GeneralizedAttention(nn.Module):
contiguous().\ contiguous().\
view(1, 1, h*w, h_kv*w_kv) view(1, 1, h*w, h_kv*w_kv)
energy = energy.masked_fill_(cur_local_constraint_map, energy = energy.masked_fill_(cur_local_constraint_map.bool(),
float('-inf')) float('-inf'))
attention = F.softmax(energy, 3) attention = F.softmax(energy, 3)
......
...@@ -32,4 +32,4 @@ def parse_version_info(version_str: str, length: int = 4) -> tuple: ...@@ -32,4 +32,4 @@ def parse_version_info(version_str: str, length: int = 4) -> tuple:
version_info = tuple(int(x) for x in __version__.split('.')[:3]) version_info = tuple(int(x) for x in __version__.split('.')[:3])
__all__ = ['__version__', 'version_info', 'parse_version_info'] __all__ = ['__version__', '__dcu_version__', 'version_info', 'parse_version_info']
...@@ -257,13 +257,19 @@ def get_extensions(): ...@@ -257,13 +257,19 @@ def get_extensions():
extra_compile_args = {'cxx': []} extra_compile_args = {'cxx': []}
if platform.system() != 'Windows': if platform.system() != 'Windows':
extra_compile_args['cxx'] = ['-std=c++14'] if parse_version(torch.__version__) <= parse_version('1.12.1'):
extra_compile_args['cxx'] = ['-std=c++14']
else:
extra_compile_args['cxx'] = ['-std=c++17']
else: else:
# TODO: In Windows, C++17 is chosen to compile extensions in # TODO: In Windows, C++17 is chosen to compile extensions in
# PyTorch2.0 , but a compile error will be reported. # PyTorch2.0 , but a compile error will be reported.
# As a temporary solution, force the use of C++14. # As a temporary solution, force the use of C++14.
if parse_version(torch.__version__) >= parse_version('2.0.0'): if parse_version(torch.__version__) <= parse_version('1.12.1'):
extra_compile_args['cxx'] = ['/std:c++14'] extra_compile_args['cxx'] = ['/std:c++14']
else:
extra_compile_args['cxx'] = ['/std:c++17']
include_dirs = [] include_dirs = []
library_dirs = [] library_dirs = []
...@@ -477,7 +483,10 @@ def get_extensions(): ...@@ -477,7 +483,10 @@ def get_extensions():
# to compile those cpp files, so there is no need to add the # to compile those cpp files, so there is no need to add the
# argument # argument
if 'nvcc' in extra_compile_args and platform.system() != 'Windows': if 'nvcc' in extra_compile_args and platform.system() != 'Windows':
extra_compile_args['nvcc'] += ['-std=c++14'] if parse_version(torch.__version__) <= parse_version('1.12.1'):
extra_compile_args['nvcc'] += ['-std=c++14']
else:
extra_compile_args['nvcc'] += ['-std=c++17']
ext_ops = extension( ext_ops = extension(
name=ext_name, name=ext_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment