#!/usr/bin/env bash
#
# Installs PyTorch from the "pytorch-<UPLOAD_CHANNEL>" conda channel into the
# local conda environment (./env) and then builds torchvision in develop mode.
#
# Environment read by this script:
#   CU_VERSION      "cpu" or a CUDA tag such as "cu118".
#   UPLOAD_CHANNEL  suffix of the "pytorch-<channel>" conda channel.

# The unit tests run against the nightly PyTorch installed further down, so
# PYTORCH_VERSION is not needed here. Leaving it set would force us to hardcode
# a PyTorch version in the CI config.
unset PYTORCH_VERSION

# Abort on the first failing command (-e) and trace every command (-x).
set -ex
# Absolute path of the directory that contains this script.
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

# Load conda's bash integration, then activate the local environment.
# The hook is captured in its own assignment on purpose: with `set -e` in
# effect a failing conda.exe then stops the script, whereas inside
# `eval "$(...)"` the failure is masked because eval of an empty string
# succeeds.
conda_hook="$(./conda/Scripts/conda.exe 'shell.bash' 'hook')"
eval "${conda_hook}"
conda activate ./env

# Map CU_VERSION onto the conda packages to install alongside PyTorch.
#
# Globals read:
#   CU_VERSION         "cpu", or "cu" followed by the CUDA major and minor
#                      digits (cu92 -> 9.2, cu118 -> 11.8).
#   CUDA_VERSION       used as-is when CU_VERSION is not such a tag.
# Globals written:
#   CUDA_VERSION       "<major>.<minor>" decoded from CU_VERSION (CUDA builds).
#   cuda_toolkit_pckg  name of the conda package selected for that CUDA version.
#   version            "cpu", or the first two dot-separated components of
#                      CUDA_VERSION.
#   cudatoolkit        package spec: "cpuonly" or "<cuda_toolkit_pckg>=<version>".

# Print the "<major>.<minor>" version encoded in a cuXY / cuXYZ tag.
# Arguments: $1 - tag to decode
# Returns:   1 (no output) when $1 is not "cu" followed by 2 or 3 digits.
cuda_version_from_tag() {
    local tag="${1:-}"
    # The last digit is the minor version, the one or two before it the major.
    local -r tag_re='^cu([0-9]{1,2})([0-9])$'
    [[ "${tag}" =~ ${tag_re} ]] || return 1
    printf '%s.%s\n' "${BASH_REMATCH[1]}" "${BASH_REMATCH[2]}"
}

# Set version / cudatoolkit (and CUDA_VERSION for CUDA builds) from CU_VERSION.
# Returns: 1 with a message on stderr when no CUDA version can be determined.
resolve_cudatoolkit() {
    if [[ "${CU_VERSION:-}" == cpu ]]; then
        cudatoolkit="cpuonly"
        version="cpu"
        return 0
    fi

    # A well-formed tag wins; otherwise fall back to a CUDA_VERSION that is
    # already set, and refuse to continue with an empty version.
    local decoded
    if decoded="$(cuda_version_from_tag "${CU_VERSION:-}")"; then
        CUDA_VERSION="${decoded}"
    fi
    if [[ -z "${CUDA_VERSION:-}" ]]; then
        printf 'Cannot determine the CUDA version: CU_VERSION="%s" is neither "cpu" nor a cuXY/cuXYZ tag\n' \
            "${CU_VERSION:-}" >&2
        return 1
    fi

    # These CUDA versions are installed through the pytorch-cuda package;
    # every other version uses cudatoolkit.
    case "${CUDA_VERSION}" in
        11.6 | 11.7 | 11.8 | 12.1) cuda_toolkit_pckg="pytorch-cuda" ;;
        *) cuda_toolkit_pckg="cudatoolkit" ;;
    esac

    echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION"

    # Keep only "<major>.<minor>" (e.g. 11.8.0 -> 11.8); a value without a dot
    # is used unchanged.
    local -r two_parts_re='^[^.]*\.[^.]*'
    if [[ "${CUDA_VERSION}" =~ ${two_parts_re} ]]; then
        version="${BASH_REMATCH[0]}"
    else
        version="${CUDA_VERSION}"
    fi
    cudatoolkit="${cuda_toolkit_pckg}=${version}"
}

# Deliberately a bare call: `set -e` (enabled at the top of this script) turns
# a non-zero return into an abort, after the reason was printed to stderr.
resolve_cudatoolkit
# Install PyTorch from the "pytorch-<UPLOAD_CHANNEL>" channel, restricted to
# builds whose build string contains ${version}, together with the CPU/CUDA
# package spec held in ${cudatoolkit}.
: "${UPLOAD_CHANNEL:?UPLOAD_CHANNEL must be set to select the pytorch-<channel> conda channel}"

printf "Installing PyTorch with %s\n" "${cudatoolkit}"

# The match spec is quoted as one word: left unquoted, its [build=...] suffix
# is a glob bracket expression as far as the shell is concerned.
conda install -y \
    -c "pytorch-${UPLOAD_CHANNEL}" \
    -c nvidia \
    "pytorch-${UPLOAD_CHANNEL}::pytorch[build=*${version}*]" \
    "${cudatoolkit}"
# Fail when a CUDA build was requested but the installed torch cannot use CUDA.
# Arguments: $1 - printed value of torch.cuda.is_available()
#            $2 - requested CUDA version; empty when none was requested
# Returns:   1 (with a message on stderr) when $2 is non-empty and $1 is
#            "False"; 0 otherwise.
check_torch_cuda() {
    local available="${1:-}"
    local requested="${2:-}"
    if [[ -n "${requested}" && "${available}" == "False" ]]; then
        echo "torch with cuda installed but torch.cuda.is_available() is False" >&2
        return 1
    fi
    return 0
}

torch_cuda="$(python -c "import torch; print(torch.cuda.is_available())")"
# NOTE(review): a native Windows Python may terminate the line with CRLF; drop
# a trailing carriage return so the string comparison still matches.
torch_cuda="${torch_cuda%$'\r'}"
echo "torch.cuda.is_available is ${torch_cuda}"

check_torch_cuda "${torch_cuda}" "${CUDA_VERSION:-}" || exit 1

# Pull in the CUDA-related environment settings kept next to this script.
source "$this_dir/set_cuda_envs.sh"

# Build torchvision in develop mode. The command is launched through
# vc_env_helper.bat, located next to this script (presumably it prepares the
# Visual C++ build environment; confirm in that file).
printf '%s\n' "* Installing torchvision"
"${this_dir:?this_dir must point at the directory of this script}/vc_env_helper.bat" python setup.py develop