controller.py
from typing import List, Optional, Tuple, Union

try:
    import ray
except ImportError:
    ray = None
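# Note: Ray is only required when the controller is constructed with
# use_ray=True; otherwise the workers run in the current process.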

from cacheflow.core.scheduler import Scheduler
from cacheflow.worker.worker import Worker


DeviceID = Tuple[int, str, int]  # rank, node resource (node IP), device id


class Controller:

    def __init__(
        self,
        stage_id: int,
        stage_devices: List[DeviceID],
        world_size: int,
        tensor_parallel_size: int,
        pipeline_parallel_size: int,
        distributed_init_method: str,
        model_name: str,
        dtype: str,
        seed: int,
        cache_dir: Optional[str],
        use_dummy_weights: bool,
        use_np_cache: bool,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        use_ray: bool,
    ) -> None:
        self.stage_id = stage_id
        self.stage_devices = stage_devices
        self.model_name = model_name
        self.use_ray = use_ray

        # Which pipeline stage is this node assigned to?
        self.is_first_stage = stage_id == 0
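        # Whether this is the last stage is only known once set_next() is
        # called with the next node (another Controller or the Scheduler).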
        self.is_last_stage = False

        self.workers: List[Worker] = []
        for rank, node_resource, device_id in stage_devices:
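            # Create one worker per assigned device, either as a Ray actor
            # or as an in-process Worker instance.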
            if self.use_ray:
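                # Reserve one GPU and request a tiny fraction of the node's
                # resource so the actor is placed on that specific node.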
                worker_cls = ray.remote(num_cpus=0,
                                        num_gpus=1,
                                        resources={node_resource: 1e-5})(Worker).remote
            else:
                worker_cls = Worker
            worker = worker_cls(
                model_name=model_name,
                dtype=dtype,
                seed=seed,
                distributed_init_method=distributed_init_method,
                rank=rank,
                world_size=world_size,
                tensor_parallel_size=tensor_parallel_size,
                pipeline_parallel_size=pipeline_parallel_size,
                cache_dir=cache_dir,
                use_dummy_weights=use_dummy_weights,
                use_np_cache=use_np_cache,
                max_num_batched_tokens=max_num_batched_tokens,
                max_num_sequences=max_num_sequences,
            )
            self.workers.append(worker)

    def get_num_available_blocks(self, block_size: int, cpu_swap_space: int,
                                 gpu_memory_utilization: float) -> List[Tuple[int, int]]:
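        # Query every worker in this stage for its number of available cache
        # blocks; with Ray, the remote calls run concurrently and are
        # gathered with ray.get().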
        all_worker_results = []
        for worker in self.workers:
            executor = worker.get_num_available_blocks
            if self.use_ray:
                executor = executor.remote

            result = executor(
                block_size,
                cpu_swap_space,
                gpu_memory_utilization,
            )
            all_worker_results.append(result)
        if self.use_ray:
            all_worker_results = ray.get(all_worker_results)
        return all_worker_results

    def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
        all_worker_futures = []
        for worker in self.workers:
            executor = worker.init_cache_engine
            if self.use_ray:
                executor = executor.remote
            future = executor(
                block_size,
                num_gpu_blocks,
                num_cpu_blocks,
            )
            all_worker_futures.append(future)
        if self.use_ray:
            ray.get(all_worker_futures)

    def set_next(
        self,
        next_node: Union['Controller', 'Scheduler'],
    ) -> None:
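        # This stage is the last one in the pipeline iff its successor is the
        # Scheduler rather than another Controller.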
        self.next_node = next_node
        self.is_last_stage = isinstance(next_node, Scheduler)

    def execute_stage(self, *args, **kwargs) -> None:
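        # Execute the stage on every worker, then check that all (replicated)
        # workers returned identical outputs before passing the result on.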
        all_outputs = []
        for worker in self.workers:
            executor = (worker.execute_stage.remote
                        if self.use_ray else worker.execute_stage)
            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if self.use_ray:
            all_outputs = ray.get(all_outputs)

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output

        if self.is_last_stage:
            self.next_node.post_step(output)
        else:
            # TODO: Support pipeline parallelism.
            raise NotImplementedError('Pipeline parallelism is not supported yet.')