Unverified Commit 23fea871 authored by TJian's avatar TJian Committed by GitHub
Browse files

[Bugfix] Fix try-catch conditions to import correct Flash Attention Backend in Draft Model (#9101)

parent f4dd830e
...@@ -6,11 +6,16 @@ from vllm.forward_context import set_forward_context ...@@ -6,11 +6,16 @@ from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
try: try:
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError: except (ModuleNotFoundError, ImportError):
# vllm_flash_attn is not installed, use the identical ROCm FA metadata # vllm_flash_attn is not installed, try the ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import ( from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata) ROCmFlashAttentionMetadata as FlashAttentionMetadata)
except (ModuleNotFoundError, ImportError) as err:
raise RuntimeError(
"Draft model speculative decoding currently only supports"
"CUDA and ROCm flash attention backend.") from err
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment