Commit c40bf687 authored by zhuwenwen
Browse files

support dtk2304

parent ed252f75
...@@ -32,14 +32,14 @@ ...@@ -32,14 +32,14 @@
__inline__ __device__ float WarpAllReduceMax(float val) { __inline__ __device__ float WarpAllReduceMax(float val) {
for (int mask = 1; mask < 32; mask *= 2) { for (int mask = 1; mask < 32; mask *= 2) {
val = max(val, __shfl_xor_sync(0xffffffff, val, mask)); val = max(val, __shfl_xorc(val, mask));
} }
return val; return val;
} }
__inline__ __device__ float WarpAllReduceSum(float val) { __inline__ __device__ float WarpAllReduceSum(float val) {
for (int mask = 1; mask < 32; mask *= 2) { for (int mask = 1; mask < 32; mask *= 2) {
val += __shfl_xor_sync(0xffffffff, val, mask); val += __shfl_xor( val, mask);
} }
return val; return val;
} }
......
...@@ -32,30 +32,28 @@ version_dependent_macros = [ ...@@ -32,30 +32,28 @@ version_dependent_macros = [
extra_cuda_flags = [ extra_cuda_flags = [
'-std=c++14', '-std=c++14',
#'-maxrregcount=50', #'-maxrregcount=50',
#'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
#'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__'
'--expt-relaxed-constexpr',
'--expt-extended-lambda'
] ]
cc_flag = ['-DAMDGPU_TARGETS', 'gfx900;gfx906'] cc_flag = ['-DAMDGPU_TARGETS', 'gfx900;gfx906;gfx926']
extra_cuda_flags += cc_flag extra_cuda_flags += cc_flag
def get_sha(pytorch_root: Union[str, Path]) -> str: def get_sha(root: Union[str, Path]) -> str:
try: try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=pytorch_root).decode('ascii').strip() return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=root).decode('ascii').strip()
except Exception: except Exception:
return 'Unknown' return 'Unknown'
def get_version_add(sha: Optional[str] = None) -> str: def get_version_add(sha: Optional[str] = None) -> str:
openfold_root = os.path.dirname(os.path.abspath(__file__))
add_version_path = "version.py" add_version_path = "version.py"
if sha != 'Unknown': if sha != 'Unknown':
if sha is None: if sha is None:
sha_path = os.getenv('OPENFOLD_DOWNLOAD_PATH', "") sha = get_sha(openfold_root)
sha = get_sha(sha_path)
version = 'git' + sha[:7] version = 'git' + sha[:7]
if os.getenv('OPENFOLD_BUILD_VERSION'): if os.getenv('OPENFOLD_BUILD_VERSION'):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment