Unverified commit 23fea871, authored by TJian and committed by GitHub
Browse files

[Bugfix] Fix try-catch conditions to import correct Flash Attention Backend in Draft Model (#9101)

parent f4dd830e
...@@ -6,11 +6,16 @@ from vllm.forward_context import set_forward_context ...@@ -6,11 +6,16 @@ from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
try: try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata try:
except ModuleNotFoundError: from vllm.attention.backends.flash_attn import FlashAttentionMetadata
# vllm_flash_attn is not installed, use the identical ROCm FA metadata except (ModuleNotFoundError, ImportError):
from vllm.attention.backends.rocm_flash_attn import ( # vllm_flash_attn is not installed, try the ROCm FA metadata
ROCmFlashAttentionMetadata as FlashAttentionMetadata) from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
except (ModuleNotFoundError, ImportError) as err:
raise RuntimeError(
"Draft model speculative decoding currently only supports"
"CUDA and ROCm flash attention backend.") from err
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment