# tpu_worker.py
1
import os
2
from typing import List, Optional, Tuple, Union
3
4
5

import torch
import torch_xla.core.xla_model as xm
6
import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
7
8
9
10
import torch_xla.runtime as xr

import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
11
                         MultiModalConfig, ParallelConfig, SchedulerConfig)
12
13
14
15
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
16
from vllm.sequence import ExecuteModelRequest
17
18
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.worker.tpu_model_runner import TPUModelRunner
19
20
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                     LoraNotSupportedWorkerBase, WorkerInput)
21
22
23
24

logger = init_logger(__name__)


25
class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
    """Worker that drives a `TPUModelRunner` on a single XLA (TPU) device.

    The worker owns a device-side (TPU) and a host-side (CPU) KV cache, sizes
    them from a profiling run, optionally pre-compiles the model for all input
    shapes, and services block swap-in / swap-out / copy requests between the
    two caches. Only virtual engine 0 (no pipeline parallelism) is supported.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        multimodal_config: Optional[MultiModalConfig],
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool,
    ) -> None:
        """Store the configs, resolve the KV cache dtype and build the runner.

        Note: ``parallel_config.rank`` is overwritten with ``rank`` in place,
        so the caller's config object is mutated.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.parallel_config.rank = rank
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.multimodal_config = multimodal_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker

        assert self.device_config.device_type == "tpu"
        # "auto" means the KV cache reuses the model dtype; any other value
        # is a string key that is mapped to the corresponding torch dtype.
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        self.model_runner: TPUModelRunner = TPUModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config,
            multimodal_config,
            is_driver_worker=is_driver_worker)

    def init_device(self) -> None:
        """Bind this worker to its XLA device and set process-wide state.

        Sets ``PJRT_DEVICE=TPU``, disables autograd, sets the default dtype to
        the model dtype, initializes the (gloo-backed) distributed groups,
        seeds the RNGs, raises dynamo's graph cache limit and enables the XLA
        persistent compilation cache.
        """
        os.environ["PJRT_DEVICE"] = "TPU"
        self.device = xm.xla_device()
        self.device_config.device = self.device
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)

        # NOTE(woosuk): This is just a hack to initialize the TP group.
        # This cannot perform the actual communication ops.
        init_distributed_environment(
            world_size=self.parallel_config.world_size,
            rank=self.rank,
            local_rank=self.local_rank,
            distributed_init_method=self.distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)

        # Set random seed.
        set_random_seed(self.model_config.seed)
        xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
        # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
        # 30-40 graphs for decode. 128 is an arbitrary safe number.
        torch._dynamo.config.cache_size_limit = 128
        # Use persistent cache to avoid XLA recompilation.
        # NOTE(woosuk): This does not completely eliminate the recompilation
        # overhead because dynamo does not cache the compiled results.
        xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH, readonly=False)

    def load_model(self):
        """Load the model weights through the model runner."""
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profile memory and return ``(num_tpu_blocks, num_cpu_blocks)``.

        A dummy prefill of ``max_num_batched_tokens`` tokens is run with empty
        KV caches so that ``bytes_used`` reflects weights plus activations.
        TPU blocks fill the remaining budget of
        ``bytes_limit * gpu_memory_utilization``; CPU blocks fill
        ``swap_space_bytes``. Both counts are clamped at 0 and rounded down to
        a multiple of 8.
        """
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        kv_caches = [(None, None) for _ in range(num_layers)]
        self.model_runner._dummy_run(
            batch_size=1,
            seq_len=self.scheduler_config.max_num_batched_tokens,
            kv_caches=kv_caches,
            is_prompt=True,
        )
        # Synchronize before measuring the memory usage.
        xm.wait_device_ops()

        # Reuse the single source of truth for the per-block footprint
        # (K + V, all layers, cache dtype) instead of re-deriving it here.
        block_size_bytes = self.get_cache_block_size_bytes()

        # Calculate the TPU KV cache size based on profiling.
        m = xm.get_memory_info(self.device)
        total_memory_size = m["bytes_limit"]
        usable_memory_size = int(total_memory_size *
                                 self.cache_config.gpu_memory_utilization)
        profiled = m["bytes_used"]  # Weights + intermediate activations.
        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
        num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
        num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.

        # Calculate the CPU KV cache size based on the config.
        num_cpu_blocks = (self.cache_config.swap_space_bytes //
                          block_size_bytes)
        num_cpu_blocks = (num_cpu_blocks // 8) * 8  # Round down to 8.
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Allocate zeroed per-layer (K, V) caches on TPU and CPU, then warm up.

        Args:
            num_gpu_blocks: Number of blocks for the TPU-resident cache (the
                name follows the shared worker interface).
            num_cpu_blocks: Number of blocks for the host swap cache.
        """
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self.block_size = self.cache_config.block_size

        dtype = self.cache_dtype
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()

        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        # The tensor layout is decided by the attention backend.
        tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_gpu_blocks, self.block_size, num_kv_heads, head_size)
        cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_cpu_blocks, self.block_size, num_kv_heads, head_size)
        for _ in range(num_layers):
            tpu_k_cache = torch.zeros(tpu_cache_shape,
                                      dtype=dtype,
                                      device=self.device)
            tpu_v_cache = torch.zeros_like(tpu_k_cache)
            self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
            cpu_k_cache = torch.zeros(cpu_cache_shape,
                                      dtype=dtype,
                                      device="cpu")
            cpu_v_cache = torch.zeros_like(cpu_k_cache)
            self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
        self._warmup_model()

    def _warmup_model(self) -> None:
        """Pre-compile the model for all input shapes unless eager is forced."""
        # FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
        # for CUDA graphs. We should refactor this part.
        if not self.model_config.enforce_eager:
            # Warm up the model with all possible input shapes so that
            # compilation never happens during the actual execution.
            # This may take ~30 mins for the first run and ~20 mins for the
            # subsequent runs.
            # If `enforce_eager` is True, the ahead-of-time compilation is
            # skipped and the compilation happens during the actual execution,
            # which is bad for performance but useful for development.
            self.model_runner.warmup_model(self.tpu_cache)

    def get_cache_block_size_bytes(self) -> int:
        """Return the bytes one cache block occupies across all layers.

        Counts both the K and the V half, each being
        ``block_size * num_kv_heads * head_size`` elements of the cache dtype.
        """
        head_size = self.model_config.get_head_size()
        num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        key_cache_block = self.cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.cache_dtype)
        return dtype_size * total

    @property
    def do_metadata_broadcast(self) -> bool:
        """Metadata is never broadcast; each worker prepares its own input."""
        # TODO(woosuk): Support TP.
        return False

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """KV caches indexed by virtual engine (only engine 0 exists)."""
        # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
        # parallelism.
        return [self.tpu_cache]

    def prepare_worker_input(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> WorkerInput:
        """Turn the request's block mappings into device index tensors.

        Each index tensor is placed on the device whose cache it indexes:
        swap-in reads CPU and writes TPU, swap-out reads TPU and writes CPU,
        and copies stay on the TPU.
        """
        virtual_engine = execute_model_req.virtual_engine
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        blocks_to_swap_in = _make_src_to_dst(
            execute_model_req.blocks_to_swap_in, "cpu", self.device)
        blocks_to_swap_out = _make_src_to_dst(
            execute_model_req.blocks_to_swap_out, self.device, "cpu")
        blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
                                          self.device, self.device)
        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )

    def execute_worker(self, worker_input: WorkerInput) -> None:
        """Apply the swap-in, swap-out and copy operations in that order.

        Blocks are addressed along dim 1 of every cache tensor (``[:, idx]``),
        which assumes the layout produced by
        ``attn_backend.get_kv_cache_shape`` — TODO confirm against the backend.
        """
        virtual_engine = worker_input.virtual_engine
        assert virtual_engine == 0
        attn_backend = self.model_runner.attn_backend
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        # Issue cache operations.
        if worker_input.blocks_to_swap_in is not None:
            src_indices, dst_indices = worker_input.blocks_to_swap_in
            if src_indices.numel() > 0:
                # Swap from CPU to TPU.
                for i in range(num_layers):
                    tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
                    cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
                    k = cpu_k_cache[:, src_indices].to(self.device)
                    v = cpu_v_cache[:, src_indices].to(self.device)
                    _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)

        if worker_input.blocks_to_swap_out is not None:
            src_indices, dst_indices = worker_input.blocks_to_swap_out
            if src_indices.numel() > 0:
                # Swap from TPU to CPU. Gather on the TPU (where src_indices
                # live), then move the gathered blocks to the host explicitly
                # so the CPU-side index assignment only sees host tensors,
                # mirroring the explicit device transfer on swap-in.
                for i in range(num_layers):
                    tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
                    cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
                    k = tpu_k_cache[:, src_indices].cpu()
                    v = tpu_v_cache[:, src_indices].cpu()
                    cpu_k_cache[:, dst_indices] = k
                    cpu_v_cache[:, dst_indices] = v

        if worker_input.blocks_to_copy is not None:
            src_indices, dst_indices = worker_input.blocks_to_copy
            if src_indices.numel() > 0:
                attn_backend.copy_blocks(self.tpu_cache,
                                         (src_indices, dst_indices))


def _make_src_to_dst(
    mapping: List[Tuple[int, int]],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    src_indices = [i for i, _ in mapping]
    dst_indices = [i for _, i in mapping]
    src_indices = torch.tensor(src_indices,
                               device=src_device,
                               dtype=torch.int64)
    dst_indices = torch.tensor(dst_indices,
                               device=dst_device,
                               dtype=torch.int64)
    return src_indices, dst_indices
281
282
283
284
285
286
287
288
289
290
291
292
293
294


@torch.compile(backend="openxla")
def _insert_kv(
    k: torch.Tensor,
    v: torch.Tensor,
    indices: torch.Tensor,
    tpu_k_cache: torch.Tensor,
    tpu_v_cache: torch.Tensor,
) -> None:
    """Scatter ``k`` / ``v`` into the TPU KV caches at block ``indices``.

    Performs ``tpu_k_cache[:, indices] = k`` and ``tpu_v_cache[:, indices] = v``
    in place; the function is compiled through dynamo's ``openxla`` backend.
    ``k``/``v`` are presumably already on the XLA device and shaped like
    ``cache[:, indices]`` — confirm at call sites.
    """
    # Mark both caches as buffer donors so the compiled XLA graph can update
    # them in place instead of materializing new cache-sized buffers.
    # NOTE(review): the op is presumably registered by the
    # `torch_xla.experimental.dynamo_set_buffer_donor` side-effect import.
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
    tpu_k_cache[:, indices] = k
    tpu_v_cache[:, indices] = v