"deploy/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "6eb2c4ab7e463913410911309691b83ee0ef1119"
Commit 844912dc authored by Tri Dao's avatar Tri Dao
Browse files

[CI] Switch from CUDA 12.2 to 12.3

parent 40e534a7
...@@ -45,7 +45,7 @@ jobs: ...@@ -45,7 +45,7 @@ jobs:
os: [ubuntu-20.04] os: [ubuntu-20.04]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240514'] torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240514']
cuda-version: ['11.8.0', '12.2.2'] cuda-version: ['11.8.0', '12.3.2']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
...@@ -60,7 +60,7 @@ jobs: ...@@ -60,7 +60,7 @@ jobs:
python-version: '3.12' python-version: '3.12'
# Pytorch <= 2.0 only supports CUDA <= 11.8 # Pytorch <= 2.0 only supports CUDA <= 11.8
- torch-version: '2.0.1' - torch-version: '2.0.1'
cuda-version: '12.2.2' cuda-version: '12.3.2'
steps: steps:
- name: Checkout - name: Checkout
...@@ -145,8 +145,8 @@ jobs: ...@@ -145,8 +145,8 @@ jobs:
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Limit MAX_JOBS otherwise the github runner goes OOM # Limit MAX_JOBS otherwise the github runner goes OOM
# CUDA 11.8 can compile with 2 jobs, but CUDA 12.2 goes OOM # CUDA 11.8 can compile with 2 jobs, but CUDA 12.3 goes OOM
MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "122" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
......
...@@ -269,9 +269,9 @@ def get_wheel_url(): ...@@ -269,9 +269,9 @@ def get_wheel_url():
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_cuda_version = parse(torch.version.cuda) torch_cuda_version = parse(torch.version.cuda)
torch_version_raw = parse(torch.__version__) torch_version_raw = parse(torch.__version__)
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
# to save CI time. Minor versions should be compatible. # to save CI time. Minor versions should be compatible.
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform() platform_name = get_platform()
flash_version = get_package_version() flash_version = get_package_version()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment