controller.py 5.52 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import sys
from typing import Literal

import pytest

from dynamo.planner import LocalConnector
from dynamo.runtime import DistributedRuntime, dynamo_worker

# The helpers below are named ``test_*`` and would be picked up by pytest
# collection; this module-level mark makes pytest skip everything in the file.
pytestmark = pytest.mark.skip("This is not a test file")

# Component names accepted by the helper functions in this module.
ComponentType = Literal["VllmWorker", "PrefillWorker"]
# Runtime list holding the same names as ``ComponentType``; the two must be
# kept in sync by hand.
VALID_COMPONENTS = ["VllmWorker", "PrefillWorker"]


async def test_state_management(connector: LocalConnector) -> bool:
    """Check that the connector's state can be loaded and written back.

    The state returned by ``connector._load_state()`` is handed unchanged (the
    same object, not a copy) to ``connector._save_state()``.

    Args:
        connector: Connector whose private state helpers are exercised.

    Returns:
        ``True`` only when loading did not raise and the save call reported
        success. ``False`` when the save reported failure or when any
        exception was raised along the way.
    """
    print("\n=== Testing State Management ===")
    try:
        # Test load state
        state = await connector._load_state()
        print("✓ Load state successful")

        # Test save state by writing back exactly what was just loaded
        success = await connector._save_state(state)
        print(
            f"{'✓' if success else '✗'} Save state {'successful' if success else 'failed'}"
        )

        # Report the save outcome instead of unconditionally claiming success,
        # so the return value agrees with the message printed above.
        return bool(success)
    except Exception as e:
        # Script-level boundary: report the failure rather than propagate it.
        print(f"✗ State management test failed: {e}")
        return False


async def test_add_component(
    connector: LocalConnector, component: ComponentType
) -> bool:
    """Ask the connector to add ``component`` and report the outcome.

    Args:
        connector: Connector on which ``add_component`` is awaited.
        component: Name of the component to add.

    Returns:
        Whatever ``connector.add_component`` returned, or ``False`` if an
        exception was raised.
    """
    print(f"\n=== Testing Add Component: {component} ===")
    try:
        added = await connector.add_component(component)
        if added:
            mark, outcome = "✓", "successful"
        else:
            mark, outcome = "✗", "failed"
        print(f"{mark} Add {component} {outcome}")
        return added
    except Exception as exc:
        # Any failure is reported on stdout and mapped to a False result.
        print(f"✗ Add {component} test failed: {exc}")
        return False


async def test_remove_component(
    connector: LocalConnector, component: ComponentType
) -> bool:
    """Ask the connector to remove ``component`` and verify the resulting state.

    Watchers are looked up in ``state["components"]`` under two key shapes:
    ``"<namespace>_<component>"`` (base watcher) and
    ``"<namespace>_<component>_<int>"`` (numbered watcher).

    * No numbered watcher, base watcher present: the removal result is
      returned without further verification.
    * No watcher at all: reported as a failure.
    * Numbered watchers present: the one with the highest suffix is
      remembered, ``remove_component`` is awaited, and the reloaded state is
      checked. For a suffix greater than zero the remembered key must be
      gone; otherwise the base watcher entry must have no ``allocated_gpus``
      under ``resources``.

    Args:
        connector: Connector providing ``namespace``, ``_load_state`` and
            ``remove_component``.
        component: Name of the component to remove.

    Returns:
        ``True`` when the removal succeeded and the state check passed (or,
        for the base-only case, whatever ``remove_component`` returned);
        ``False`` otherwise, including when any exception was raised.
    """
    print(f"\n=== Testing Remove Component: {component} ===")
    try:
        state = await connector._load_state()
        components = state["components"]
        base_component = f"{connector.namespace}_{component}"
        base_name = f"{base_component}_"

        # Collect (suffix, actual key) pairs for every numbered watcher.
        matching_components = []
        for watcher_name in components:
            if not watcher_name.startswith(base_name):
                continue
            try:
                # Strip only the leading prefix; str.replace() would also
                # remove any later occurrence of it inside the name.
                suffix = int(watcher_name[len(base_name) :])
            except ValueError:
                # Not a numbered watcher of this component.
                continue
            matching_components.append((suffix, watcher_name))

        if not matching_components:
            if base_component not in components:
                print(f"✗ No {component} components found to remove")
                return False
            success = await connector.remove_component(component)
            print(
                f"{'✓' if success else '✗'} Remove {component} {'successful' if success else 'failed'}"
            )
            return success

        # Remember which watcher we expect to disappear. Keep the key exactly
        # as stored in the state: rebuilding it from the parsed integer would
        # turn e.g. "..._01" into "..._1" and verify the wrong key.
        # NOTE(review): assumes remove_component drops the highest-numbered
        # watcher first - confirm against LocalConnector.
        highest_suffix, target_component = max(matching_components)

        success = await connector.remove_component(component)
        if not success:
            print(f"✗ Failed to remove {component}")
            return False

        new_state = await connector._load_state()
        new_components = new_state["components"]

        # Numbered watcher (suffix > 0): its entry must be gone entirely.
        if highest_suffix > 0:
            if target_component not in new_components:
                print(f"✓ Successfully removed {target_component}")
                return True
            print(f"✗ Failed to remove {target_component} from state")
            return False

        # Otherwise the base watcher entry is expected to remain, with its
        # GPU allocation cleared.
        if base_component in new_components:
            resources = new_components[base_component].get("resources", {})
            if not resources.get("allocated_gpus"):
                print(f"✓ Successfully cleared resources for {base_component}")
                return True
            print(f"✗ Failed to clear resources for {base_component}")
            return False

        # Neither verification path applied.
        print(f"✗ Unexpected state after removing {component}")
        return False

    except Exception as e:
        # Script-level boundary: report the failure rather than propagate it.
        print(f"✗ Remove {component} test failed: {e}")
        return False


@dynamo_worker()
async def main(runtime: DistributedRuntime):
    """Add a prefill worker and a vLLM worker, then remove them again.

    A ``LocalConnector`` is created for the ``"dynamo"`` namespace. The
    components are added in the order PrefillWorker, VllmWorker and removed
    in the opposite order. Return values of the connector calls are ignored.

    Args:
        runtime: Distributed runtime handed to the ``LocalConnector``.
    """
    connector = LocalConnector("dynamo", runtime)

    startup_order = ("PrefillWorker", "VllmWorker")
    for worker in startup_order:
        await connector.add_component(worker)
    # Tear down in reverse: the component added last is removed first.
    for worker in reversed(startup_order):
        await connector.remove_component(worker)


if __name__ == "__main__":
    # Run the worker entry point to completion and use its result as the
    # process exit status.
    # NOTE(review): main() is called without arguments, so the ``runtime``
    # parameter is presumably supplied by the @dynamo_worker() decorator.
    exit_status = asyncio.run(main())
    sys.exit(exit_status)