Unverified Commit 2cc5affc authored by Concurrensee's avatar Concurrensee Committed by GitHub
Browse files

[ROCM][CI] Fix AMD Examples Test Group (#30276)


Signed-off-by: default avatarYida Wu <yida.wu@amd.com>
Signed-off-by: default avatarYida <yida.wu@amd.com>
parent a00d8897
...@@ -435,7 +435,7 @@ steps: ...@@ -435,7 +435,7 @@ steps:
- label: Examples Test # 30min - label: Examples Test # 30min
timeout_in_minutes: 45 timeout_in_minutes: 45
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1 agent_pool: mi325_1
# grade: Blocking # grade: Blocking
working_dir: "/vllm-workspace/examples" working_dir: "/vllm-workspace/examples"
...@@ -455,7 +455,6 @@ steps: ...@@ -455,7 +455,6 @@ steps:
# for multi-modal models # for multi-modal models
- python3 offline_inference/audio_language.py --seed 0 - python3 offline_inference/audio_language.py --seed 0
- python3 offline_inference/vision_language.py --seed 0 - python3 offline_inference/vision_language.py --seed 0
- python3 offline_inference/vision_language_pooling.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
# for pooling models # for pooling models
......
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
from argparse import Namespace from argparse import Namespace
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
...@@ -20,6 +23,11 @@ def parse_args(): ...@@ -20,6 +23,11 @@ def parse_args():
def main(args: Namespace): def main(args: Namespace):
if current_platform.is_rocm():
args.attention_config = AttentionConfig(
backend=AttentionBackendEnum.FLEX_ATTENTION
)
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
......
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
from argparse import Namespace from argparse import Namespace
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
...@@ -20,6 +23,11 @@ def parse_args(): ...@@ -20,6 +23,11 @@ def parse_args():
def main(args: Namespace): def main(args: Namespace):
if current_platform.is_rocm():
args.attention_config = AttentionConfig(
backend=AttentionBackendEnum.FLEX_ATTENTION
)
# Sample prompts. # Sample prompts.
text_1 = "What is the capital of France?" text_1 = "What is the capital of France?"
texts_2 = [ texts_2 = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment