controller.py

from typing import Dict, List, Union, Tuple

try:
    import ray
except ImportError:
    ray = None

from cacheflow.master.scheduler import Scheduler
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.worker.worker import Worker


DeviceID = Tuple[int, str, int]  # rank, node resource (node IP), device id


class Controller:
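    """Manages the workers that make up a single pipeline stage.

    Each call to execute_stage is dispatched to every worker in the stage
    (as Ray remote calls when use_ray is set); the last stage passes the
    resulting output back to the Scheduler.
    """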

    def __init__(
        self,
        stage_id: int,
        stage_devices: List[DeviceID],
        world_size: int,
        tensor_parallel_size: int,
        pipeline_parallel_size: int,
        distributed_init_method: str,
        model_name: str,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: str,
        seed: int,
        model_path: str,
        use_dummy_weights: bool,
        max_num_batched_tokens: int,
        use_ray: bool,
    ) -> None:
        self.stage_id = stage_id
        self.stage_devices = stage_devices
        self.model_name = model_name
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.use_ray = use_ray

        # Which pipeline stage is this node assigned to?
        self.is_first_stage = stage_id == 0
        self.is_last_stage = False

        self.workers: List[Worker] = []
        for rank, node_resource, device_id in stage_devices:
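            # Launch each worker either as a Ray actor pinned to its node
            # (via the node's custom resource) or as a plain local Worker.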
            if self.use_ray:
                worker_cls = ray.remote(num_cpus=0,
                                        num_gpus=1,
                                        resources={node_resource: 1e-5})(Worker).remote
            else:
                worker_cls = Worker
            worker = worker_cls(
                model_name=model_name,
                block_size=block_size,
                num_gpu_blocks=num_gpu_blocks,
                num_cpu_blocks=num_cpu_blocks,
                dtype=dtype,
                seed=seed,
                distributed_init_method=distributed_init_method,
                rank=rank,
                world_size=world_size,
                tensor_parallel_size=tensor_parallel_size,
                pipeline_parallel_size=pipeline_parallel_size,
                model_path=model_path,
                use_dummy_weights=use_dummy_weights,
                max_num_batched_tokens=max_num_batched_tokens,
            )
            self.workers.append(worker)

    def set_next(
        self,
        next_node: Union['Controller', 'Scheduler'],
    ) -> None:
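        """Connects this stage to the next node in the pipeline.

        This stage is the last one if the next node is the Scheduler.
        """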
        self.next_node = next_node
        self.is_last_stage = isinstance(next_node, Scheduler)

    def execute_stage(
        self,
        input_seq_groups: List[SequenceGroupInputs],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
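        """Runs one scheduler step on every worker of this stage.

        All workers receive the same inputs and cache operations and are
        expected to produce identical outputs.
        """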
        all_outputs = []
        for worker in self.workers:
            executor = (worker.execute_stage.remote
                        if self.use_ray else worker.execute_stage)
            output = executor(
                input_seq_groups,
                blocks_to_swap_in,
                blocks_to_swap_out,
                blocks_to_copy,
            )
            all_outputs.append(output)

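        # With Ray, the calls above returned futures; wait for them to finish.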
        if self.use_ray:
            all_outputs = ray.get(all_outputs)

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output

        if self.is_last_stage:
            self.next_node.post_step(output)
        else:
            # TODO: Support pipeline parallelism.
            raise NotImplementedError('Pipeline parallelism is not supported yet.')