from typing import List, Optional, Tuple, Union

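# Ray is optional; it is only needed when use_ray is True.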
try:
    import ray
except ImportError:
    ray = None

from cacheflow.core.scheduler import Scheduler
from cacheflow.worker.worker import Worker


DeviceID = Tuple[int, str, int]  # rank, node resource (node IP), device id


class Controller:
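    """Controls the workers assigned to a single pipeline stage.

    One Worker is created per device in `stage_devices` (wrapped in a Ray
    actor when `use_ray` is True). `execute_stage` fans a step out to every
    worker and forwards the common output to the node set via `set_next`,
    which is either the next stage's Controller or, for the last stage, the
    Scheduler.
    """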

    def __init__(
        self,
        stage_id: int,
        stage_devices: List[DeviceID],
        world_size: int,
        tensor_parallel_size: int,
        pipeline_parallel_size: int,
        distributed_init_method: str,
        model_name: str,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: str,
        seed: int,
        cache_dir: Optional[str],
        use_dummy_weights: bool,
        use_np_cache: bool,
        max_num_batched_tokens: int,
        use_ray: bool,
    ) -> None:
        self.stage_id = stage_id
        self.stage_devices = stage_devices
        self.model_name = model_name
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.use_ray = use_ray

        # Which pipeline stage is this node assigned to?
        self.is_first_stage = stage_id == 0
        self.is_last_stage = False

        self.workers: List[Worker] = []
        for rank, node_resource, device_id in stage_devices:
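            # With Ray, each worker runs as an actor that holds one GPU and is
            # pinned to its node by requesting a sliver of the node's resource.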
            if self.use_ray:
                worker_cls = ray.remote(num_cpus=0,
                                        num_gpus=1,
                                        resources={node_resource: 1e-5})(Worker).remote
            else:
                worker_cls = Worker
            worker = worker_cls(
                model_name=model_name,
                block_size=block_size,
                num_gpu_blocks=num_gpu_blocks,
                num_cpu_blocks=num_cpu_blocks,
                dtype=dtype,
                seed=seed,
                distributed_init_method=distributed_init_method,
                rank=rank,
                world_size=world_size,
                tensor_parallel_size=tensor_parallel_size,
                pipeline_parallel_size=pipeline_parallel_size,
                cache_dir=cache_dir,
                use_dummy_weights=use_dummy_weights,
                use_np_cache=use_np_cache,
                max_num_batched_tokens=max_num_batched_tokens,
            )
            self.workers.append(worker)

    def set_next(
        self,
        next_node: Union['Controller', 'Scheduler'],
    ) -> None:
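        """Connect this stage to the next Controller or to the Scheduler."""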
        self.next_node = next_node
        self.is_last_stage = isinstance(next_node, Scheduler)

    def execute_stage(self, *args, **kwargs) -> None:
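        """Run one step on every worker and forward the output downstream."""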
        all_outputs = []
        for worker in self.workers:
            executor = (worker.execute_stage.remote
                        if self.use_ray else worker.execute_stage)
            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if self.use_ray:
            all_outputs = ray.get(all_outputs)

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output

        if self.is_last_stage:
            self.next_node.post_step(output)
        else:
            # TODO: Support pipeline parallelism.
            assert False
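

# Rough usage sketch (hypothetical names and values, not part of this module):
# stages are chained with set_next(), and the last stage points at the
# Scheduler, which receives each step's output through post_step().
#
#   controller = Controller(stage_id=0,
#                           stage_devices=[(0, "node:10.0.0.1", 0)], ...)
#   controller.set_next(scheduler)      # single stage: output goes to Scheduler
#   controller.execute_stage(*inputs)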