"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "a93a53f8a1302c992ad185e70c6ab4affe43c4d7"
Unverified Commit d9c7db18 authored by Micah Williamson's avatar Micah Williamson Committed by GitHub
Browse files

[ROCm][CI] Pin test_hybrid test to TRITON_ATTN on ROCm (#38381)


Signed-off-by: default avatarMicah Williamson <micah.williamson@amd.com>
parent 12701e8a
...@@ -57,6 +57,8 @@ FP32_STATE_MODELS = [ ...@@ -57,6 +57,8 @@ FP32_STATE_MODELS = [
# Avoid OOM # Avoid OOM
MAX_NUM_SEQS = 4 MAX_NUM_SEQS = 4
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "auto"
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
...@@ -82,7 +84,9 @@ def test_models( ...@@ -82,7 +84,9 @@ def test_models(
example_prompts, max_tokens, num_logprobs example_prompts, max_tokens, num_logprobs
) )
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: with vllm_runner(
model, max_num_seqs=MAX_NUM_SEQS, attention_backend=ATTN_BACKEND
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs( vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs example_prompts, max_tokens, num_logprobs
) )
...@@ -157,6 +161,7 @@ def test_chunked_prefill_with_parallel_sampling( ...@@ -157,6 +161,7 @@ def test_chunked_prefill_with_parallel_sampling(
# forces prefill chunks with decoding # forces prefill chunks with decoding
max_num_batched_tokens=MAX_NUM_SEQS * 3, max_num_batched_tokens=MAX_NUM_SEQS * 3,
max_num_seqs=MAX_NUM_SEQS, max_num_seqs=MAX_NUM_SEQS,
attention_backend=ATTN_BACKEND,
) as vllm_model: ) as vllm_model:
vllm_model.generate(example_prompts, sampling_params) vllm_model.generate(example_prompts, sampling_params)
...@@ -301,7 +306,9 @@ def test_full_cuda_graph( ...@@ -301,7 +306,9 @@ def test_full_cuda_graph(
example_prompts, max_tokens, num_logprobs example_prompts, max_tokens, num_logprobs
) )
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: with vllm_runner(
model, max_num_seqs=MAX_NUM_SEQS, attention_backend=ATTN_BACKEND
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs( vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs example_prompts, max_tokens, num_logprobs
) )
...@@ -370,6 +377,7 @@ def _get_vllm_runner_params( ...@@ -370,6 +377,7 @@ def _get_vllm_runner_params(
"max_model_len": max_model_len, "max_model_len": max_model_len,
"tensor_parallel_size": tensor_parallel_size, "tensor_parallel_size": tensor_parallel_size,
"gpu_memory_utilization": 0.4, "gpu_memory_utilization": 0.4,
"attention_backend": ATTN_BACKEND,
} }
...@@ -844,6 +852,7 @@ def test_apc_common_prefix_same_batch( ...@@ -844,6 +852,7 @@ def test_apc_common_prefix_same_batch(
mamba_block_size=16, mamba_block_size=16,
enable_prefix_caching=True, enable_prefix_caching=True,
seed=42, seed=42,
attention_backend=ATTN_BACKEND,
) )
prompts = [ prompts = [
"hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501 "hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment