Commit 1c8a0898 authored by rusty1s's avatar rusty1s
Browse files

fix windows

parent feb2a136
......@@ -64,6 +64,14 @@ else
export FORCE_CUDA=1
fi
if [ "${IDX}" == "cu100" ] || [ "${IDX}" == "cu101" ]; then
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_50,code=compute_50"
fi
if [ "${IDX}" == "cu92" ]; then
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_50,code=compute_50"
fi
if [ "${TRAVIS_OS_NAME}" = "linux" ] && [ "${IDX}" != "cpu" ]; then
INSTALLER=cuda-repo-${UBUNTU_VERSION}_${CUDA}_amd64.deb
wget -nv "http://developer.download.nvidia.com/compute/cuda/repos/${UBUNTU_VERSION}/x86_64/${INSTALLER}"
......@@ -82,10 +90,15 @@ fi
if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "${IDX}" != "cpu" ]; then
wget -nv "${CUDA_URL}/${CUDA_FILE}"
PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT}\" -Wait -NoNewWindow"
ls /c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/VC/Tools/MSVC
CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v${CUDA_SHORT}
PATH=${CUDA_HOME}/bin:$PATH
PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH
nvcc --version
cat -n /c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v${CUDA_SHORT}/include/crt/host_config.h
fi
# Fix Cuda9.2 on Windows: https://github.com/pytorch/pytorch/issues/6109
if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "${IDX}" = "cu92" ]; then
cat -n "${CUDA_PATH}/include/crt/host_config.h"
sed -i.bak -e '129,141d' "${CUDA_PATH}/include/crt/host_config.h"
cat -n "${CUDA_PATH}/include/crt/host_config.h"
fi
......@@ -26,7 +26,7 @@ def get_extensions():
define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
nvcc_flags += ['--expt-relaxed-constexpr']
extra_compile_args['nvcc'] = nvcc_flags
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment