"vscode:/vscode.git/clone" did not exist on "917fdae5b2eccf0e7b6f2d4ae67132d13d13580c"
Unverified Commit 06d6c5fe authored by Matt Wong's avatar Matt Wong Committed by GitHub
Browse files

[Bugfix][CI/Build][Hardware][AMD] Fix AMD tests, add HF cache, update CK FA,...

[Bugfix][CI/Build][Hardware][AMD] Fix AMD tests, add HF cache, update CK FA, add partially supported model notes (#6543)
parent 683e3cb9
...@@ -66,11 +66,18 @@ trap remove_docker_container EXIT ...@@ -66,11 +66,18 @@ trap remove_docker_container EXIT
echo "--- Running container" echo "--- Running container"
HF_CACHE="$(realpath ~)/huggingface"
mkdir -p ${HF_CACHE}
HF_MOUNT="/root/.cache/huggingface"
docker run \ docker run \
--device /dev/kfd --device /dev/dri \ --device /dev/kfd --device /dev/dri \
--network host \ --network host \
--shm-size=16gb \
--rm \ --rm \
-e HF_TOKEN \ -e HF_TOKEN \
-v ${HF_CACHE}:${HF_MOUNT} \
-e HF_HOME=${HF_MOUNT} \
--name ${container_name} \ --name ${container_name} \
${image_name} \ ${image_name} \
/bin/bash -c "${@}" /bin/bash -c "${@}"
......
...@@ -44,7 +44,8 @@ steps: ...@@ -44,7 +44,8 @@ steps:
mirror_hardwares: [amd] mirror_hardwares: [amd]
fast_check: true fast_check: true
commands: commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl # This flashinfer installation will fail on AMD ROCm, so it is set as optional.
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true
- pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py - pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
......
...@@ -33,7 +33,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 ...@@ -33,7 +33,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# versions are derived from Dockerfile.rocm # versions are derived from Dockerfile.rocm
# #
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1") set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1")
set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0") set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")
# #
# Try to find python package with an executable that exactly matches # Try to find python package with an executable that exactly matches
...@@ -101,7 +101,7 @@ elseif(HIP_FOUND) ...@@ -101,7 +101,7 @@ elseif(HIP_FOUND)
# ROCm 5.X and 6.X # ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM}) NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} " message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.") "expected for ROCm build, saw ${Torch_VERSION} instead.")
endif() endif()
else() else()
......
...@@ -4,18 +4,21 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" ...@@ -4,18 +4,21 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
# Default ROCm ARCHes to build vLLM for. # Default ROCm ARCHes to build vLLM for.
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
# Whether to build CK-based flash-attention # Whether to install CK-based flash-attention
# If 0, will not build flash attention # If 0, will not install flash-attention
# This is useful for gfx target where flash-attention is not supported
# (i.e. those that do not appear in `FA_GFX_ARCHS`)
# Triton FA is used by default on ROCm now so this is unnecessary.
ARG BUILD_FA="1" ARG BUILD_FA="1"
# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL`
# If this succeeds, we use the downloaded wheel and skip building flash-attention.
# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the
# architectures specified in `FA_GFX_ARCHS`
ARG TRY_FA_WHEEL="1"
ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl"
ARG FA_GFX_ARCHS="gfx90a;gfx942" ARG FA_GFX_ARCHS="gfx90a;gfx942"
ARG FA_BRANCH="ae7928c" ARG FA_BRANCH="23a2b1c2"
# Whether to build triton on rocm # Whether to build triton on rocm
ARG BUILD_TRITON="1" ARG BUILD_TRITON="1"
ARG TRITON_BRANCH="0ef1848" ARG TRITON_BRANCH="e0fc12c"
### Base image build stage ### Base image build stage
FROM $BASE_IMAGE AS base FROM $BASE_IMAGE AS base
...@@ -43,15 +46,15 @@ RUN apt-get update && apt-get install -y \ ...@@ -43,15 +46,15 @@ RUN apt-get update && apt-get install -y \
ARG APP_MOUNT=/vllm-workspace ARG APP_MOUNT=/vllm-workspace
WORKDIR ${APP_MOUNT} WORKDIR ${APP_MOUNT}
RUN pip install --upgrade pip RUN python3 -m pip install --upgrade pip
# Remove sccache so it doesn't interfere with ccache # Remove sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components # TODO: implement sccache support across components
RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)" RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
# Install torch == 2.5.0 on ROCm # Install torch == 2.5.0 on ROCm
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.1"*) \ *"rocm-6.1"*) \
pip uninstall -y torch torchaudio torchvision \ python3 -m pip uninstall -y torch torchaudio torchvision \
&& pip install --no-cache-dir --pre \ && python3 -m pip install --no-cache-dir --pre \
torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \ torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
torchvision==0.20.0.dev20240710 \ torchvision==0.20.0.dev20240710 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
...@@ -70,24 +73,31 @@ ENV CCACHE_DIR=/root/.cache/ccache ...@@ -70,24 +73,31 @@ ENV CCACHE_DIR=/root/.cache/ccache
FROM base AS build_amdsmi FROM base AS build_amdsmi
# Build amdsmi wheel always # Build amdsmi wheel always
RUN cd /opt/rocm/share/amd_smi \ RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=/install && python3 -m pip wheel . --wheel-dir=/install
### Flash-Attention wheel build stage ### Flash-Attention wheel build stage
FROM base AS build_fa FROM base AS build_fa
ARG BUILD_FA ARG BUILD_FA
ARG TRY_FA_WHEEL
ARG FA_WHEEL_URL
ARG FA_GFX_ARCHS ARG FA_GFX_ARCHS
ARG FA_BRANCH ARG FA_BRANCH
# Build ROCm flash-attention wheel if `BUILD_FA = 1` # Build ROCm flash-attention wheel if `BUILD_FA = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \ RUN --mount=type=cache,target=${CCACHE_DIR} \
if [ "$BUILD_FA" = "1" ]; then \ if [ "$BUILD_FA" = "1" ]; then \
mkdir -p libs \ if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \
&& cd libs \ # If a suitable wheel exists, we download it instead of building FA
&& git clone https://github.com/ROCm/flash-attention.git \ mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \
&& cd flash-attention \ else \
&& git checkout "${FA_BRANCH}" \ mkdir -p libs \
&& git submodule update --init \ && cd libs \
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ && git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& git checkout "${FA_BRANCH}" \
&& git submodule update --init \
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
fi; \
# Create an empty directory otherwise as later build stages expect one # Create an empty directory otherwise as later build stages expect one
else mkdir -p /install; \ else mkdir -p /install; \
fi fi
...@@ -126,7 +136,7 @@ RUN case "$(which python3)" in \ ...@@ -126,7 +136,7 @@ RUN case "$(which python3)" in \
# Package upgrades for useful functionality or to avoid dependency issues # Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
pip install --upgrade numba scipy huggingface-hub[cli] python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
# Make sure punica kernels are built (for LoRA) # Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1 ENV VLLM_INSTALL_PUNICA_KERNELS=1
...@@ -137,7 +147,7 @@ ENV TOKENIZERS_PARALLELISM=false ...@@ -137,7 +147,7 @@ ENV TOKENIZERS_PARALLELISM=false
RUN --mount=type=cache,target=${CCACHE_DIR} \ RUN --mount=type=cache,target=${CCACHE_DIR} \
--mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \ python3 -m pip install -Ur requirements-rocm.txt \
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.1"*) \ *"rocm-6.1"*) \
# Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
...@@ -153,7 +163,7 @@ RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \ ...@@ -153,7 +163,7 @@ RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
mkdir -p libs \ mkdir -p libs \
&& cp /install/*.whl libs \ && cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs # Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y amdsmi; && python3 -m pip uninstall -y amdsmi;
# Copy triton wheel(s) into final image if they were built # Copy triton wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \ RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
...@@ -161,7 +171,7 @@ RUN --mount=type=bind,from=build_triton,src=/install,target=/install \ ...@@ -161,7 +171,7 @@ RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
&& if ls /install/*.whl; then \ && if ls /install/*.whl; then \
cp /install/*.whl libs \ cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs # Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y triton; fi && python3 -m pip uninstall -y triton; fi
# Copy flash-attn wheel(s) into final image if they were built # Copy flash-attn wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \ RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
...@@ -169,11 +179,11 @@ RUN --mount=type=bind,from=build_fa,src=/install,target=/install \ ...@@ -169,11 +179,11 @@ RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
&& if ls /install/*.whl; then \ && if ls /install/*.whl; then \
cp /install/*.whl libs \ cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs # Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y flash-attn; fi && python3 -m pip uninstall -y flash-attn; fi
# Install wheels that were built to the final image # Install wheels that were built to the final image
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
if ls libs/*.whl; then \ if ls libs/*.whl; then \
pip install libs/*.whl; fi python3 -m pip install libs/*.whl; fi
CMD ["/bin/bash"] CMD ["/bin/bash"]
...@@ -90,12 +90,12 @@ Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTor ...@@ -90,12 +90,12 @@ Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTor
Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton <https://github.com/ROCm/triton/blob/triton-mlir/README.md>`_ Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton <https://github.com/ROCm/triton/blob/triton-mlir/README.md>`_
2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm>`_ 2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/ck_tile>`_
Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_ Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support>`_
Alternatively, wheels intended for vLLM use can be accessed under the releases.
.. note:: .. note::
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
- You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
3. Build vLLM. 3. Build vLLM.
...@@ -110,5 +110,6 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl ...@@ -110,5 +110,6 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl
.. tip:: .. tip::
- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
- Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
- To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention. - To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
- The ROCm version of PyTorch, ideally, should match the ROCm driver version. - The ROCm version of PyTorch, ideally, should match the ROCm driver version.
...@@ -2,5 +2,9 @@ ...@@ -2,5 +2,9 @@
-r requirements-common.txt -r requirements-common.txt
# Dependencies for AMD GPUs # Dependencies for AMD GPUs
awscli
boto3
botocore
ray >= 2.10.0 ray >= 2.10.0
peft
pytest-asyncio pytest-asyncio
from vllm.utils import is_hip
from ..utils import compare_two_settings from ..utils import compare_two_settings
def test_cpu_offload(): def test_cpu_offload():
compare_two_settings("meta-llama/Llama-2-7b-hf", [], compare_two_settings("meta-llama/Llama-2-7b-hf", [],
["--cpu-offload-gb", "4"]) ["--cpu-offload-gb", "4"])
compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", if not is_hip():
[], ["--cpu-offload-gb", "1"]) # compressed-tensors quantization is currently not supported in ROCm.
compare_two_settings(
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [],
["--cpu-offload-gb", "1"])
import os
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
import pytest import pytest
...@@ -5,6 +6,7 @@ from transformers import AutoTokenizer ...@@ -5,6 +6,7 @@ from transformers import AutoTokenizer
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import is_hip
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close from .utils import check_logprobs_close
...@@ -22,6 +24,12 @@ IMAGE_TOKEN_ID = 257152 ...@@ -22,6 +24,12 @@ IMAGE_TOKEN_ID = 257152
models = ["google/paligemma-3b-mix-224"] models = ["google/paligemma-3b-mix-224"]
# ROCm Triton FA can run into compilation issues with these models due to,
# excessive use of shared memory. Use other backends in the meantime.
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
def vllm_to_hf_output(vllm_output: Tuple[List[int], str, def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]], Optional[SampleLogprobs]],
...@@ -130,7 +138,15 @@ def run_test( ...@@ -130,7 +138,15 @@ def run_test(
[0.25, 0.5, 1.0], [0.25, 0.5, 1.0],
], ],
) )
@pytest.mark.parametrize("dtype", ["float", "half"]) @pytest.mark.parametrize("dtype", [
pytest.param(
"float",
marks=pytest.mark.skipif(
is_hip(),
reason=
"ROCm FA does not yet fully support 32-bit precision on PaliGemma")
), "half"
])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
......
import os
import re import re
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
...@@ -6,7 +7,7 @@ from transformers import AutoTokenizer ...@@ -6,7 +7,7 @@ from transformers import AutoTokenizer
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu from vllm.utils import is_cpu, is_hip
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close from .utils import check_logprobs_close
...@@ -47,6 +48,12 @@ target_dtype = "half" ...@@ -47,6 +48,12 @@ target_dtype = "half"
if is_cpu(): if is_cpu():
target_dtype = "bfloat16" target_dtype = "bfloat16"
# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
def run_test( def run_test(
hf_runner: Type[HfRunner], hf_runner: Type[HfRunner],
......
...@@ -275,6 +275,12 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -275,6 +275,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
triton_attention) triton_attention)
self.attn_func = triton_attention self.attn_func = triton_attention
logger.debug("Using Triton FA in ROCmBackend") logger.debug("Using Triton FA in ROCmBackend")
if self.sliding_window != (-1, -1):
logger.warning("ROCm Triton FA does not currently support "
"sliding window attention. If using half "
"precision, please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
else: else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn # if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either # either
...@@ -434,6 +440,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -434,6 +440,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_k=prefill_meta.max_prefill_seq_len, max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
) )
# common code for prefill # common code for prefill
......
...@@ -87,13 +87,24 @@ _ROCM_UNSUPPORTED_MODELS: List[str] = [] ...@@ -87,13 +87,24 @@ _ROCM_UNSUPPORTED_MODELS: List[str] = []
# Models partially supported by ROCm. # Models partially supported by ROCm.
# Architecture -> Reason. # Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"Qwen2ForCausalLM": "Qwen2ForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention", _ROCM_SWA_REASON,
"MistralForCausalLM": "MistralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention", _ROCM_SWA_REASON,
"MixtralForCausalLM": "MixtralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention", _ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration":
("ROCm flash attention does not yet "
"fully support 32-bit precision on PaliGemma"),
"Phi3VForCausalLM":
("ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
} }
......
...@@ -3,7 +3,14 @@ from typing import List, Optional ...@@ -3,7 +3,14 @@ from typing import List, Optional
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment