#!/bin/bash
#
# patch_xformers.rocm.sh - apply the ROCm patches shipped in ./rocm_patch to
# the installed xformers package (xformers/ops/fmha/flash.py and
# xformers/ops/fmha/common.py).
#
# Usage (run from the directory that contains rocm_patch/):
#   bash patch_xformers.rocm.sh
#
# Environment:
#   PYTHON  interpreter used to import xformers (default: python)
#
# The script is idempotent: a target file whose patch already reverse-applies
# cleanly is reported as "patched before" and left untouched.

set -euo pipefail

readonly XFORMERS_VERSION="0.0.23"
readonly PYTHON="${PYTHON:-python}"
readonly PATCH_DIR="./rocm_patch"

# Print an error message to stderr and abort.
die() {
  printf 'ERROR: %s\n' "$*" >&2
  exit 1
}

#######################################
# Apply a patch to a file unless it has already been applied.
# Arguments:
#   $1 - file to patch
#   $2 - patch file (applied with -p0)
# Outputs:
#   progress messages on stdout, errors on stderr
# Returns:
#   0 on success; exits the script on any failure.
#######################################
apply_patch_once() {
  local target=$1
  local patch_file=$2

  [[ -f "$patch_file" ]] || die "patch file not found: ${patch_file}"
  [[ -f "$target" ]] || die "file to patch not found: ${target}"

  # If the patch can be cleanly *reversed*, the target already contains it.
  if patch -R -p0 -s -f --dry-run "$target" "$patch_file"; then
    echo "${target} was patched before"
    return 0
  fi

  echo "Applying patch to ${target}"
  # -N: never prompt "Assume -R?" on a partially applied file (would block
  # non-interactive builds); fail instead so the problem is visible.
  patch -N -p0 "$target" "$patch_file" \
    || die "failed to apply ${patch_file} to ${target}"
  echo "Successfully patched ${target}"
}

main() {
  local installed_version flash_path common_path

  # Assign separately from any 'export'/'local' so a failing interpreter is
  # caught here instead of producing an empty value.
  installed_version=$("$PYTHON" -c 'import xformers; print(xformers.__version__)') \
    || die "could not import xformers using '${PYTHON}'"

  if [[ "$installed_version" != "$XFORMERS_VERSION" ]]; then
    die "xformers version must be ${XFORMERS_VERSION}. ${installed_version} is installed"
  fi

  flash_path=$("$PYTHON" -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)') \
    || die "could not locate xformers.ops.fmha.flash"
  common_path=$("$PYTHON" -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)') \
    || die "could not locate xformers.ops.fmha.common"

  echo "XFORMERS_FMHA_FLASH_PATH = ${flash_path}"
  echo "XFORMERS_FMHA_COMMON_PATH = ${common_path}"

  apply_patch_once "$flash_path" \
    "${PATCH_DIR}/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"
  apply_patch_once "$common_path" \
    "${PATCH_DIR}/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"
}

main "$@"