Commit d1a63639 (unverified), authored by Jinze (Richard) Xue, committed by GitHub
Browse files

Fix --cuaev-all-sms (#600)

* Fix --cuaev-all-sms

* expecttest

* fix ci
parent bf771af0
...@@ -2,6 +2,6 @@ ...@@ -2,6 +2,6 @@
pip install --upgrade pip pip install --upgrade pip
pip install twine wheel pip install twine wheel
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu110/torch_nightly.html pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu112/torch_nightly.html
pip install -r test_requirements.txt pip install -r test_requirements.txt
pip install -r docs_requirements.txt pip install -r docs_requirements.txt
...@@ -53,10 +53,9 @@ def maybe_download_cub(): ...@@ -53,10 +53,9 @@ def maybe_download_cub():
def cuda_extension(build_all=False): def cuda_extension(build_all=False):
import torch import torch
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
SMs = None SMs = []
print('-' * 75) print('-' * 75)
if not build_all: if not build_all:
SMs = []
devices = torch.cuda.device_count() devices = torch.cuda.device_count()
print('FAST_BUILD_CUAEV: ON') print('FAST_BUILD_CUAEV: ON')
print('This build will only support the following devices or the devices with same cuda capability: ') print('This build will only support the following devices or the devices with same cuda capability: ')
...@@ -71,13 +70,13 @@ def cuda_extension(build_all=False): ...@@ -71,13 +70,13 @@ def cuda_extension(build_all=False):
SMs.append(sm) SMs.append(sm)
nvcc_args = ["-Xptxas=-v", '--expt-extended-lambda', '-use_fast_math'] nvcc_args = ["-Xptxas=-v", '--expt-extended-lambda', '-use_fast_math']
if SMs: if SMs and not ONLY_BUILD_SM80:
for sm in SMs: for sm in SMs:
nvcc_args.append(f"-gencode=arch=compute_{sm},code=sm_{sm}") nvcc_args.append(f"-gencode=arch=compute_{sm},code=sm_{sm}")
elif len(SMs) == 0 and ONLY_BUILD_SM80: # --cuaev --only-sm80 elif ONLY_BUILD_SM80: # --cuaev --only-sm80
nvcc_args.append("-gencode=arch=compute_80,code=sm_80") nvcc_args.append("-gencode=arch=compute_80,code=sm_80")
else: # no gpu detected else: # no gpu detected
print('NO gpu detected, will build for all SMs') print('Will build for all SMs')
nvcc_args.append("-gencode=arch=compute_60,code=sm_60") nvcc_args.append("-gencode=arch=compute_60,code=sm_60")
nvcc_args.append("-gencode=arch=compute_61,code=sm_61") nvcc_args.append("-gencode=arch=compute_61,code=sm_61")
nvcc_args.append("-gencode=arch=compute_70,code=sm_70") nvcc_args.append("-gencode=arch=compute_70,code=sm_70")
......
...@@ -6,3 +6,4 @@ pillow ...@@ -6,3 +6,4 @@ pillow
pkbar pkbar
pyyaml pyyaml
pytest pytest
expecttest
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment