Unverified commit 6c64c41b, authored by Micah Williamson, committed by GitHub
Browse files

[ROCm][CI] Force max_num_seqs=1 on ROCm in test_sharded_state_loader to reduce flakiness (#33277)


Signed-off-by: Micah Williamson <micah.williamson@amd.com>
parent a2ef06e1
...@@ -13,6 +13,7 @@ from huggingface_hub import snapshot_download ...@@ -13,6 +13,7 @@ from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.model_executor.model_loader import ShardedStateLoader from vllm.model_executor.model_loader import ShardedStateLoader
from vllm.platforms import current_platform
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
...@@ -95,6 +96,10 @@ def test_sharded_state_loader( ...@@ -95,6 +96,10 @@ def test_sharded_state_loader(
input_dir = llama_3p2_1b_files input_dir = llama_3p2_1b_files
ctx = mp.get_context("spawn") ctx = mp.get_context("spawn")
platform_args = {}
if current_platform.is_rocm():
platform_args["max_num_seqs"] = 1
# Run in separate processes for memory & CUDA isolation # Run in separate processes for memory & CUDA isolation
with TemporaryDirectory() as output_dir: with TemporaryDirectory() as output_dir:
p = ctx.Process( p = ctx.Process(
...@@ -104,6 +109,7 @@ def test_sharded_state_loader( ...@@ -104,6 +109,7 @@ def test_sharded_state_loader(
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
gpu_memory_utilization=gpu_memory_utilization, gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=True, enforce_eager=True,
**platform_args,
), ),
) )
p.start() p.start()
...@@ -118,6 +124,7 @@ def test_sharded_state_loader( ...@@ -118,6 +124,7 @@ def test_sharded_state_loader(
enable_lora=enable_lora, enable_lora=enable_lora,
gpu_memory_utilization=gpu_memory_utilization, gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
**platform_args,
), ),
) )
p.start() p.start()
...@@ -141,6 +148,7 @@ def test_sharded_state_loader( ...@@ -141,6 +148,7 @@ def test_sharded_state_loader(
gpu_memory_utilization=gpu_memory_utilization, gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
load_format="sharded_state", load_format="sharded_state",
**platform_args,
), ),
) )
p.start() p.start()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment