Unverified commit f375ec84, authored by TJian and committed by GitHub

[ROCm] Upgrade xformers version for ROCm & update doc (#2079)


Co-authored-by: miloice <jeffaw99@hotmail.com>
parent 518369d7
@@ -47,12 +47,12 @@ RUN mkdir libs \
 COPY ./ /app/vllm
 RUN python3 -m pip install --upgrade pip
-RUN pip install xformers==0.0.22.post7 --no-deps
+RUN pip install xformers==0.0.23 --no-deps
 RUN cd /app \
     && cd vllm \
     && pip install -U -r requirements-rocm.txt \
-    && bash patch_xformers-0.0.22.post7.rocm.sh \
+    && bash patch_xformers-0.0.23.rocm.sh \
     && python3 setup.py install \
     && cd ..
...
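For a quick end-to-end check of the updated image, the sketch below builds the ROCm image and prints the xformers version baked into it; the `Dockerfile.rocm` file name and the `vllm-rocm` tag are assumptions for illustration, not taken from this diff:

$ docker build -f Dockerfile.rocm -t vllm-rocm .    # assumed Dockerfile name and image tag
$ docker run --rm vllm-rocm python3 -c "import xformers; print(xformers.__version__)"

The second command should report the pinned 0.0.23 build if the `--no-deps` install and the patch step completed as expected.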
@@ -3,7 +3,7 @@
 Installation with ROCm
 ======================
-vLLM 0.2.x onwards supports model inferencing and serving on AMD GPUs with ROCm.
+vLLM 0.2.4 onwards supports model inferencing and serving on AMD GPUs with ROCm.
 At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported.
 Data types currently supported in ROCm are FP16 and BF16.
@@ -29,7 +29,7 @@ Installation options:
 .. code-block:: console
-    $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.3
+    $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
     $ docker run -it \
        --network=host \
        --group-add=video \
@@ -70,12 +70,12 @@ You can build and install vLLM from source:
 - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
 - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
-2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention
+2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
 .. code-block:: console
-    $ pip install xformers==0.0.22.post7 --no-deps
-    $ bash patch_xformers-0.0.22.post7.rocm.sh
+    $ pip install xformers==0.0.23 --no-deps
+    $ bash patch_xformers.rocm.sh
 3. Build vLLM.
@@ -127,12 +127,12 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from
 - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
 - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
-2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention
+2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
 .. code-block:: console
-    $ pip install xformers==0.0.22.post7 --no-deps
-    $ bash patch_xformers-0.0.22.post7.rocm.sh
+    $ pip install xformers==0.0.23 --no-deps
+    $ bash patch_xformers.rocm.sh
...
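As a quick sanity check after step 2, the following console sketch can be run; it assumes the `pip install` and the patch script both finished without errors and that a working ROCm `flash_attn` build is importable:

$ python3 -c 'import xformers; print(xformers.__version__)'
$ python3 -c 'from xformers.ops.fmha import flash; print(flash.FLASH_VERSION)'

The first line should print 0.0.23; the second prints the flash-attn version picked up by the patched `flash.py` shipped with this change.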
 #!/bin/bash
+set -e
+XFORMERS_VERSION="0.0.23"
+export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)')
+if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then
+    echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed"
+    exit 1
+fi
 export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
 export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')
-echo $XFORMERS_FMHA_FLASH_PATH
-echo $XFORMERS_FMHA_COMMON_PATH
+echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}"
+echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}"
-if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then
+if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
     echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
-    patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"
+    patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"
     echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
 else
     echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
 fi
-if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then
+if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
     echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
-    patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"
+    patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"
     echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
 else
     echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"
...
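The version guard and the reverse dry-run are what make this script safe to re-run: `patch -R --dry-run` only succeeds when the patch is already present, so a second invocation skips the `patch -p0` call instead of corrupting the file. A standalone sketch of the same idiom, using hypothetical `target.py` and `example.patch` names:

#!/bin/bash
# Hypothetical file names, for illustration only.
TARGET="target.py"
PATCH_FILE="example.patch"
if ! patch -R -p0 -s -f --dry-run "$TARGET" "$PATCH_FILE"; then
    # Reverse dry-run failed: the change is not applied yet, so apply it now.
    patch -p0 "$TARGET" "$PATCH_FILE"
else
    # Reverse dry-run succeeded: the file already contains the change; do nothing.
    echo "$TARGET was patched before"
fi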
@@ -8,7 +8,6 @@ pyarrow # Required for Ray data.
 sentencepiece # Required for LLaMA tokenizer.
 numpy
 tokenizers>=0.15.0
-huggingface_hub<0.18,>=0.16.4
 transformers >= 4.36.0 # Required for Mixtral.
 fastapi
 uvicorn[standard]
...
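With the explicit `huggingface_hub<0.18,>=0.16.4` pin removed, the hub client version is now resolved through `transformers >= 4.36.0`'s own dependency rather than a hard pin. A quick way to see what actually gets installed (a sketch, assuming a plain pip environment inside the ROCm container or venv):

$ pip install -U -r requirements-rocm.txt
$ pip show huggingface_hub | head -n 2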
---- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000
-+++ flash.py 2023-11-28 16:14:25.206128903 +0000
-@@ -31,39 +31,39 @@
+--- flash_ori.py 2023-12-13 05:43:31.530752623 +0000
++++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000
+@@ -36,44 +36,44 @@
  FLASH_VERSION = "0.0.0"
  try:
@@ -15,9 +15,12 @@
 -        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
 -
 -        FLASH_VERSION = flash_attn.__version__
--        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
--        if flash_ver_parsed < (2, 3):
--            raise ImportError("Requires 2.3 for sliding window support")
+-        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+-        if (
+-            flash_ver_parsed != (2, 3, 6)
+-            and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+-        ):
+-            raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
 +        #try:
 +        #    from ... import _C_flashattention  # type: ignore[attr-defined]
 +        #    from ..._cpp_lib import _build_metadata
@@ -29,35 +32,41 @@
 +        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
 +
 +        FLASH_VERSION = flash_attn.__version__
-+        # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
-+        # if flash_ver_parsed < (2, 3):
-+        #     raise ImportError("Requires 2.3 for sliding window support")
++        # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
++        # if (
++        #     flash_ver_parsed != (2, 3, 6)
++        #     and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
++        # ):
++        #     raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
      # create library so that flash-attn goes through the PyTorch Dispatcher
 -    _flash_lib = torch.library.Library("xformers_flash", "DEF")
-+    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
+-
 -    _flash_lib.define(
 -        "flash_fwd(Tensor query, Tensor key, Tensor value, "
--        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+-        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
 -        "int max_seqlen_q, int max_seqlen_k, "
 -        "float p, float softmax_scale, "
--        "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+-        "bool is_causal, int window_left, "
+-        "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
 -    )
---
+++    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
 -    _flash_lib.define(
 -        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
 -        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
 -        "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
 -        "int max_seqlen_q, int max_seqlen_k, "
--        "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+-        "float p, float softmax_scale, bool is_causal, "
+-        "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
 -    )
 +    #_flash_lib.define(
 +    #    "flash_fwd(Tensor query, Tensor key, Tensor value, "
-+    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
++    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
 +    #    "int max_seqlen_q, int max_seqlen_k, "
 +    #    "float p, float softmax_scale, "
-+    #    "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
++    #    "bool is_causal, int window_left, "
++    #    "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
 +    #)
 +
 +    #_flash_lib.define(
@@ -65,52 +74,61 @@
 +    #    "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
 +    #    "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
 +    #    "int max_seqlen_q, int max_seqlen_k, "
-+    #    "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
++    #    "float p, float softmax_scale, bool is_causal, "
++    #    "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
 +    #)
      def _flash_fwd(
          query,
-@@ -98,8 +98,8 @@
+@@ -111,8 +111,8 @@
              p,
              softmax_scale,
              is_causal,
--            window_size - 1,  # window_size_left
--            -1,  # window_size_right
-+            # window_size - 1,  # window_size_left
-+            # -1,  # window_size_right
+-            window_left,  # window_size_left
+-            window_right,  # window_size_right
++            # window_left,  # window_size_left
++            # window_right,  # window_size_right
              return_softmax,
              None,  # rng
          )
-@@ -127,8 +127,8 @@
+@@ -134,15 +134,15 @@
+             out,
+             cu_seq_lens_q,
+             cu_seq_lens_k,
+-            seqused_k,
++            # seqused_k,
+             max_seq_len_q,
+             max_seq_len_k,
+             p,
              softmax_scale,
              False,
              is_causal,
--            window_size - 1,  # window_size_left
--            -1,  # window_size_right
-+            # window_size - 1,  # window_size_left
-+            # -1,  # window_size_right
+-            window_left,
+-            window_right,
++            # window_left,
++            # window_right,
              return_softmax,
              None,
          )
-@@ -169,8 +169,8 @@
+@@ -184,8 +184,8 @@
              p,
              softmax_scale,
              is_causal,
--            window_size - 1,  # window_size_left
--            -1,  # window_size_right
-+            # window_size - 1,  # window_size_left
-+            # -1,  # window_size_right
+-            window_left,
+-            window_right,
++            # window_left,
++            # window_right,
              None,
              rng_state,
          )
-@@ -193,15 +193,15 @@
+@@ -208,15 +208,15 @@
              softmax_scale,
              False,  # zero_tensors
              is_causal,
--            window_size - 1,  # window_size_left
--            -1,  # window_size_right
-+            # window_size - 1,  # window_size_left
-+            # -1,  # window_size_right
+-            window_left,
+-            window_right,
++            # window_left,
++            # window_right,
              None,
              rng_state,
          )
@@ -123,7 +141,7 @@
  except ImportError:
      pass
-@@ -348,7 +348,7 @@
+@@ -400,7 +400,7 @@
      implementation.
      """
...
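The updated patch reroutes xformers' flash-attention wrappers to ROCm's `flash_attn` package and comments out the pieces the ROCm build does not provide: the `torch.library` registration, the `window_left`/`window_right` sliding-window arguments, and the new `seqused_k` argument. A minimal smoke test of the patched path is sketched below; it assumes a ROCm-visible GPU, a working ROCm `flash_attn` build, and the patched `xformers==0.0.23` install, and it forces the flash operator instead of relying on automatic dispatch:

$ python3 - <<'EOF'
import torch
import xformers.ops as xops

# Small [batch, seq_len, heads, head_dim] FP16 tensors (FP16 is a ROCm-supported dtype).
q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Force the (patched) flash-attention forward/backward ops.
out = xops.memory_efficient_attention(q, k, v, op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp))
print(out.shape, "flash_attn seen by xformers:", xops.fmha.flash.FLASH_VERSION)
EOF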