Unverified Commit 29c9cb80 authored by Jiangyun Zhu's avatar Jiangyun Zhu Committed by GitHub
Browse files

[CI] Add tests for cudagraph (#27391)


Signed-off-by: default avatarzjy0516 <riverclouds.zhu@qq.com>
parent 83f478bb
...@@ -435,6 +435,18 @@ steps: ...@@ -435,6 +435,18 @@ steps:
- pytest -v -s compile/test_full_graph.py - pytest -v -s compile/test_full_graph.py
- pytest -v -s compile/test_fusions_e2e.py - pytest -v -s compile/test_fusions_e2e.py
- label: Cudagraph test
timeout_in_minutes: 20
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- tests/v1/cudagraph
- vllm/v1/cudagraph_dispatcher.py
- vllm/config/compilation.py
- vllm/compilation
commands:
- pytest -v -s v1/cudagraph/test_cudagraph_dispatch.py
- pytest -v -s v1/cudagraph/test_cudagraph_mode.py
- label: Kernels Core Operation Test # 48min - label: Kernels Core Operation Test # 48min
timeout_in_minutes: 75 timeout_in_minutes: 75
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
......
...@@ -1111,6 +1111,11 @@ def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None] ...@@ -1111,6 +1111,11 @@ def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]
# `cloudpickle` allows pickling complex functions directly # `cloudpickle` allows pickling complex functions directly
input_bytes = cloudpickle.dumps((f, output_filepath)) input_bytes = cloudpickle.dumps((f, output_filepath))
repo_root = str(VLLM_PATH.resolve())
env = dict(env or os.environ)
env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "")
cmd = [sys.executable, "-m", f"{module_name}"] cmd = [sys.executable, "-m", f"{module_name}"]
returned = subprocess.run( returned = subprocess.run(
......
...@@ -34,13 +34,16 @@ class SimpleMLP(nn.Module): ...@@ -34,13 +34,16 @@ class SimpleMLP(nn.Module):
def _create_vllm_config( def _create_vllm_config(
compilation_config: CompilationConfig, max_num_seqs: int = 8 compilation_config: CompilationConfig,
max_num_seqs: int = 8,
lora_config: bool = False,
) -> MagicMock: ) -> MagicMock:
mock_config = MagicMock(spec=VllmConfig) mock_config = MagicMock(spec=VllmConfig)
mock_config.compilation_config = compilation_config mock_config.compilation_config = compilation_config
mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs) mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
mock_config.parallel_config = ParallelConfig() mock_config.parallel_config = ParallelConfig()
if not lora_config:
mock_config.lora_config = None
# Mimic the behavior of VllmConfig.__post_init__() # Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.mode == CompilationMode.VLLM_COMPILE: if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1() compilation_config.set_splitting_ops_for_v1()
...@@ -50,19 +53,21 @@ def _create_vllm_config( ...@@ -50,19 +53,21 @@ def _create_vllm_config(
class TestCudagraphDispatcher: class TestCudagraphDispatcher:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"case_id,cudagraph_mode_str,compilation_mode", "cudagraph_mode_str,compilation_mode,lora_config",
[ [
# Test case 0: Full CG for mixed batches, no separate routine # Test case 0: Full CG for mixed batches, no separate routine
(0, "FULL", CompilationMode.NONE), ("FULL", CompilationMode.NONE, False),
# Test case 1: Full CG for uniform batches, piecewise for mixed # Test case 1: Full CG for uniform batches, piecewise for mixed
(1, "FULL_AND_PIECEWISE", CompilationMode.NONE), ("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
# Test case 2: Full CG for uniform batches, no CG for mixed # Test case 2: Full CG for uniform batches, no CG for mixed
(2, "FULL_DECODE_ONLY", CompilationMode.NONE), ("FULL_DECODE_ONLY", CompilationMode.NONE, False),
# Test case 3: PIECEWISE for all # Test case 3: PIECEWISE for all
(3, "PIECEWISE", CompilationMode.VLLM_COMPILE), ("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
# Test case 4: PIECEWISE for all, specialize LoRA cases
("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
], ],
) )
def test_dispatcher(self, cudagraph_mode_str, compilation_mode): def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
# Setup dispatcher # Setup dispatcher
comp_config = CompilationConfig( comp_config = CompilationConfig(
cudagraph_mode=cudagraph_mode_str, cudagraph_mode=cudagraph_mode_str,
...@@ -70,7 +75,17 @@ class TestCudagraphDispatcher: ...@@ -70,7 +75,17 @@ class TestCudagraphDispatcher:
cudagraph_capture_sizes=[1, 8], cudagraph_capture_sizes=[1, 8],
) )
config = _create_vllm_config(comp_config, max_num_seqs=8) config = _create_vllm_config(
comp_config, max_num_seqs=8, lora_config=lora_config
)
if (
cudagraph_mode_str == "FULL_AND_PIECEWISE"
and compilation_mode == CompilationMode.NONE
):
with pytest.raises(AssertionError):
dispatcher = CudagraphDispatcher(config)
return
dispatcher = CudagraphDispatcher(config) dispatcher = CudagraphDispatcher(config)
dispatcher.initialize_cudagraph_keys( dispatcher.initialize_cudagraph_keys(
cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1 cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
...@@ -78,17 +93,24 @@ class TestCudagraphDispatcher: ...@@ -78,17 +93,24 @@ class TestCudagraphDispatcher:
# Verify the key is initialized correctly # Verify the key is initialized correctly
if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]: if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2 assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
4 if lora_config else 2
)
else: else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0 assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
if cudagraph_mode_str not in ["NONE", "PIECEWISE"]: if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2 assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
4 if lora_config else 2
)
else: else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0 assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
# Test dispatch logic # Test dispatch logic
# 1. non-uniform batch, size in cudagraph size list # 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False) desc_full_exact = BatchDescriptor(
num_tokens=8,
uniform_decode=False,
)
rt_mode, key = dispatcher.dispatch(desc_full_exact) rt_mode, key = dispatcher.dispatch(desc_full_exact)
if cudagraph_mode_str == "FULL": if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL assert rt_mode == CUDAGraphMode.FULL
...@@ -138,7 +160,6 @@ class TestCUDAGraphWrapper: ...@@ -138,7 +160,6 @@ class TestCUDAGraphWrapper:
self.persistent_input_buffer = torch.zeros(1, 10, device="cuda") self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
self.input_tensor = torch.randn(1, 10, device="cuda") self.input_tensor = torch.randn(1, 10, device="cuda")
@create_new_process_for_each_test("spawn")
def test_capture_and_replay(self): def test_capture_and_replay(self):
wrapper = CUDAGraphWrapper( wrapper = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
...@@ -192,7 +213,6 @@ class TestCUDAGraphWrapper: ...@@ -192,7 +213,6 @@ class TestCUDAGraphWrapper:
eager_output = self.model(self.input_tensor) eager_output = self.model(self.input_tensor)
torch.testing.assert_close(eager_output, output2) torch.testing.assert_close(eager_output, output2)
@create_new_process_for_each_test("spawn")
def test_bypass_on_mode_mismatch(self): def test_bypass_on_mode_mismatch(self):
wrapper = CUDAGraphWrapper( wrapper = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
...@@ -216,7 +236,6 @@ class TestCUDAGraphWrapper: ...@@ -216,7 +236,6 @@ class TestCUDAGraphWrapper:
mock_forward.assert_called_once() mock_forward.assert_called_once()
assert not wrapper.concrete_cudagraph_entries assert not wrapper.concrete_cudagraph_entries
@create_new_process_for_each_test("spawn")
def test_bypass_on_mode_none(self): def test_bypass_on_mode_none(self):
wrapper = CUDAGraphWrapper( wrapper = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
......
...@@ -109,9 +109,9 @@ combo_cases_2 = [ ...@@ -109,9 +109,9 @@ combo_cases_2 = [
@pytest.mark.parametrize( @pytest.mark.parametrize(
"backend_name,cudagraph_mode,compilation_mode,supported", combo_cases_2 "backend_name,cudagraph_mode,compilation_mode,supported", combo_cases_2
) )
def test_cudagraph_compilation_combo(combo_case): def test_cudagraph_compilation_combo(
backend_name, cudagraph_mode, compilation_mode, supported = combo_case backend_name, cudagraph_mode, compilation_mode, supported
):
env_vars = backend_configs[backend_name].env_vars env_vars = backend_configs[backend_name].env_vars
with temporary_environ(env_vars), ExitStack() as stack: with temporary_environ(env_vars), ExitStack() as stack:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment