from typing import Dict, List, Union

from cacheflow.master.scheduler import Scheduler
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.worker.worker import Worker


class Controller:
    """Drives one pipeline stage.

    Owns the workers assigned to this node and forwards each stage's
    output to the next node in the pipeline (another Controller, or the
    Scheduler when this is the last stage).
    """

    def __init__(
        self,
        node_id: int,
        num_workers: int,
        model_name: str,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: str,
        seed: int,
    ) -> None:
        """Create the workers owned by this node.

        Args:
            node_id: Index of this node among the pipeline stages.
            num_workers: Number of workers (one per GPU) on this node.
            model_name: Name of the model to load on each worker.
            block_size: Number of tokens per KV-cache block.
            num_gpu_blocks: Number of GPU KV-cache blocks per worker.
            num_cpu_blocks: Number of CPU (swap) KV-cache blocks per worker.
            dtype: Data type for the model (e.g. 'half') — passed through
                to the worker.
            seed: Random seed forwarded to each worker.
        """
        self.node_id = node_id
        self.num_workers = num_workers
        self.model_name = model_name
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks

        # Which pipeline stage is this node assigned to?
        self.is_first_stage = node_id == 0
        # Updated in set_next() once the successor node is known.
        self.is_last_stage = False

        self.workers: List[Worker] = []
        for i in range(num_workers):
            worker = Worker(
                # BUG FIX: the original used `worker_id=node_id + i`, which
                # assigns overlapping ids across nodes when num_workers > 1
                # (e.g. nodes 0 and 1 would both own worker id 1). This
                # formula is globally unique and equal to the old value in
                # the num_workers == 1 case that execute_stage() asserts.
                worker_id=node_id * num_workers + i,
                gpu_id=i,
                model_name=model_name,
                block_size=block_size,
                num_gpu_blocks=num_gpu_blocks,
                num_cpu_blocks=num_cpu_blocks,
                dtype=dtype,
                seed=seed,
            )
            self.workers.append(worker)

    def set_next(
        self,
        next_node: Union['Controller', 'Scheduler'],
    ) -> None:
        """Link this stage to its successor.

        This node is the last pipeline stage iff its successor is the
        Scheduler (i.e. the pipeline loops back to scheduling).
        """
        self.next_node = next_node
        self.is_last_stage = isinstance(next_node, Scheduler)

    def execute_stage(
        self,
        input_seq_groups: List[SequenceGroupInputs],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Run one step of this pipeline stage on the node's worker.

        Args:
            input_seq_groups: Inputs for the sequence groups to run.
            blocks_to_swap_in: Block-id mapping for swap-in (presumably
                CPU block -> GPU block — confirm against Worker).
            blocks_to_swap_out: Block-id mapping for swap-out (presumably
                GPU block -> CPU block — confirm against Worker).
            blocks_to_copy: Source block id -> destination block ids to
                copy (NOTE(review): looks like copy-on-write — confirm).

        Raises:
            NotImplementedError: If this is not the last stage; pipeline
                parallelism is not implemented yet.
        """
        # FIXME: Support tensor parallelism.
        assert len(self.workers) == 1
        worker = self.workers[0]
        output = worker.execute_stage(
            input_seq_groups,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
        )

        if self.is_last_stage:
            # Hand the step's output back to the scheduler.
            self.next_node.post_step(output)
        else:
            # TODO: Support pipeline parallelism.
            # Was `assert False`, which is stripped under `python -O` and
            # would silently fall through; raise explicitly instead.
            raise NotImplementedError(
                'Pipeline parallelism is not supported yet.')