tpu_worker.py 14.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import List, Optional, Tuple, Union

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                     LoRANotSupportedWorkerBase, WorkerBase,
                                     WorkerInput)

# Module-level logger for this file, created through vLLM's `init_logger`.
logger = init_logger(__name__)


class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
    """V0 worker that runs a model on a TPU through PyTorch/XLA.

    The worker owns a `TPUModelRunner` plus the per-layer TPU and CPU KV
    caches, and performs the swap-in / swap-out / copy cache operations
    requested for each step. LoRA serving is rejected at construction time.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool,
    ) -> None:
        """Store the rank information and build the model runner.

        Args:
            vllm_config: Engine configuration; forwarded to
                `WorkerBase.__init__` and to the model runner.
            local_rank: Passed as `local_rank` to
                `init_distributed_environment` in `init_device`.
            rank: Rank of this worker; also written to
                `parallel_config.rank`.
            distributed_init_method: Passed to
                `init_distributed_environment` in `init_device`.
            is_driver_worker: Forwarded to the model runner.

        Raises:
            NotImplementedError: If `vllm_config.lora_config` is set.
        """
        WorkerBase.__init__(self, vllm_config=vllm_config)
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker

        assert self.device_config.device_type == "tpu"

        # Fail fast: reject the unsupported configuration before paying for
        # the construction of the model runner.
        if vllm_config.lora_config is not None:
            raise NotImplementedError(
                "The V0 TPU backend doesn't support LoRA serving")

        # "auto" means the KV cache uses the same dtype as the model.
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        self.model_runner: TPUModelRunner = TPUModelRunner(
            vllm_config=vllm_config, is_driver_worker=is_driver_worker)

        # `init_device` needs a concrete seed; default to 0 when unset.
        if self.model_config.seed is None:
            self.model_config.seed = 0

    def init_device(self) -> None:
        """Initialize the distributed runtime, the XLA device and tooling.

        Sets `PJRT_DEVICE`, disables autograd, initializes the (gloo)
        distributed environment, acquires the XLA device, seeds the RNGs,
        enables the persistent XLA compilation cache when
        `VLLM_XLA_CACHE_PATH` is set, and starts the profiler server on
        rank 0 when `VLLM_TORCH_PROFILER_DIR` is set.
        """
        os.environ["PJRT_DEVICE"] = "TPU"
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)

        # NOTE(woosuk): This is just to initialize the TP group and broadcast
        # the input objects on CPU. The all-reduce and all-gather ops on TPU
        # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
        # own context.
        init_distributed_environment(
            world_size=self.parallel_config.world_size,
            rank=self.rank,
            local_rank=self.local_rank,
            distributed_init_method=self.distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)

        # Device initialization should happen after initializing the
        # distributed runtime.
        self.device = xm.xla_device()
        self.device_config.device = self.device

        # Set random seed.
        set_random_seed(self.model_config.seed)
        xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
        # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
        # 30-40 graphs for decode. 128 is an arbitrary safe number.
        torch._dynamo.config.cache_size_limit = 128
        # Use persistent cache to avoid XLA recompilation.
        # NOTE(woosuk): Set per-rank cache path since different ranks
        # can have slightly different XLA graphs.
        world_size = self.parallel_config.world_size
        rank = xr.global_ordinal()
        # The PyTorch/XLA compilation cache uses the Torch IR to generate
        # keys. Consequently, changes in optimization flags, which affect
        # compilation results, don't change the cache key. This can result in
        # the wrong compilation being used. To prevent this, disabling the
        # XLA compilation cache during development is recommended. We can
        # disable it by `export VLLM_XLA_CACHE_PATH=`
        if envs.VLLM_XLA_CACHE_PATH:
            per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
                                         f"tp{world_size}_rank{rank}")
            xr.initialize_cache(per_rank_path, readonly=False)

        self.profiler = None
        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
            # For TPU, we can only have 1 active profiler session for 1
            # profiler server. So we only profile on rank0.
            self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        self.profile_dir)
            self.profiler = xp.start_server(9012)

    def start_profile(self):
        """Start a profiler trace on rank 0; do nothing on other ranks.

        Raises:
            RuntimeError: If the profiler server was not started by
                `init_device`.
        """
        if self.rank < 1:
            if self.profiler is None:
                raise RuntimeError("Profiler is not enabled.")
            xp.start_trace(self.profile_dir)

    def stop_profile(self):
        """Stop the profiler trace on rank 0; do nothing on other ranks.

        Raises:
            RuntimeError: If the profiler server was not started by
                `init_device`.
        """
        if self.rank < 1:
            if self.profiler is None:
                raise RuntimeError("Profiler is not enabled.")
            xp.stop_trace()

    def load_model(self):
        """Load the model weights through the model runner."""
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profile a prefill run and size the TPU and CPU KV caches.

        Returns:
            `(num_tpu_blocks, num_cpu_blocks)`, each rounded down to a
            multiple of 8.
        """
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        # Use an empty tensor instead of `None` to force Dynamo to pass
        # it by reference, rather than specializing on the value `None`.
        # The `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        kv_caches = [(torch.tensor([], dtype=torch.float32,
                                   device=self.device),
                      torch.tensor([], dtype=torch.float32,
                                   device=self.device))
                     for _ in range(num_layers)]
        bind_kv_cache(self.compilation_config.static_forward_context,
                      [kv_caches])
        self.model_runner._dummy_run(
            batch_size=1,
            seq_len=self.scheduler_config.max_num_batched_tokens,
            kv_caches=kv_caches,
            exec_mode=ExecutionMode.PREFILL,
        )
        # Synchronize before measuring the memory usage.
        xm.wait_device_ops()

        # Get the maximum amount of memory used by the model weights and
        # intermediate activations.
        m = xm.get_memory_info(self.device)
        total_memory_size = m["bytes_limit"]
        profiled = m["peak_bytes_used"]  # Weights + intermediate activations.

        # Calculate the TPU KV cache size based on profiling.
        usable_memory_size = int(total_memory_size *
                                 self.cache_config.gpu_memory_utilization)
        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
        # Single source of truth for the size of one block (all layers,
        # keys and values).
        block_size_bytes = self.get_cache_block_size_bytes()
        num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
        num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.

        # Calculate the CPU KV cache size based on the config.
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             block_size_bytes)
        num_cpu_blocks = (num_cpu_blocks // 8) * 8  # Round down to 8.
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Allocate the zero-filled TPU and CPU KV caches and warm up.

        Args:
            num_gpu_blocks: Number of blocks in the on-device (TPU) cache.
            num_cpu_blocks: Number of blocks in the CPU swap cache.
        """
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self.block_size = self.cache_config.block_size

        dtype = self.cache_dtype
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()

        # One (key, value) tensor pair per layer, for each device.
        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_gpu_blocks, self.block_size, num_kv_heads, head_size)
        cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_cpu_blocks, self.block_size, num_kv_heads, head_size)
        for _ in range(num_layers):
            tpu_k_cache = torch.zeros(tpu_cache_shape,
                                      dtype=dtype,
                                      device=self.device)
            tpu_v_cache = torch.zeros_like(tpu_k_cache)
            self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
            cpu_k_cache = torch.zeros(cpu_cache_shape,
                                      dtype=dtype,
                                      device="cpu")
            cpu_v_cache = torch.zeros_like(cpu_k_cache)
            self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
        bind_kv_cache(self.compilation_config.static_forward_context,
                      [self.tpu_cache])
        self._warmup_model()

    def _warmup_model(self) -> None:
        """Pre-compile the model unless `enforce_eager` is set."""
        # FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
        # for CUDA graphs. We should refactor this part.
        if not self.model_config.enforce_eager:
            # Warm up the model with all possible input shapes so that
            # compilation never happens during the actual execution.
            # This may take ~30 mins for the first run and ~20 mins for the
            # subsequent runs.
            # If `enforce_eager` is True, the ahead-of-time compilation is
            # skipped and the compilation happens during the actual execution,
            # which is bad for performance but useful for development.
            self.model_runner.warmup_model(self.tpu_cache)

    def get_cache_block_size_bytes(self) -> int:
        """Return the bytes used by one KV cache block across all layers.

        The size covers both the key and the value entries of the block.
        """
        head_size = self.model_config.get_head_size()
        num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        key_cache_block = self.cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.cache_dtype)
        return dtype_size * total

    @property
    def do_metadata_broadcast(self) -> bool:
        """Whether tensor parallelism is in use (TP size greater than 1)."""
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """The TPU KV cache wrapped in a single-element list."""
        # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
        # parallelism.
        return [self.tpu_cache]

    def prepare_worker_input(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> WorkerInput:
        """Translate a request's block mappings into index tensors.

        Args:
            execute_model_req: Request carrying the swap-in, swap-out and
                copy block mappings for this step.

        Returns:
            A `WorkerInput` whose `blocks_to_*` fields are either `None`
            (empty mapping) or a `(src_indices, dst_indices)` tensor pair
            placed on the source and destination devices respectively.
        """
        virtual_engine = execute_model_req.virtual_engine
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        blocks_to_swap_in = _make_src_to_dst(
            execute_model_req.blocks_to_swap_in, "cpu", self.device)
        blocks_to_swap_out = _make_src_to_dst(
            execute_model_req.blocks_to_swap_out, self.device, "cpu")
        blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
                                          self.device, self.device)
        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )

    def execute_worker(self, worker_input: WorkerInput) -> None:
        """Apply the swap-in, swap-out and copy cache operations.

        Args:
            worker_input: Output of `prepare_worker_input`; only
                `virtual_engine == 0` is supported.
        """
        virtual_engine = worker_input.virtual_engine
        assert virtual_engine == 0

        # Issue cache operations.
        if worker_input.blocks_to_swap_in is not None:
            src_indices, dst_indices = worker_input.blocks_to_swap_in
            if src_indices.numel() > 0:
                # Swap from CPU to TPU, layer by layer.
                for tpu_kv, cpu_kv in zip(self.tpu_cache, self.cpu_cache):
                    tpu_k_cache, tpu_v_cache = tpu_kv
                    cpu_k_cache, cpu_v_cache = cpu_kv
                    k = cpu_k_cache[:, src_indices].to(self.device)
                    v = cpu_v_cache[:, src_indices].to(self.device)
                    _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)

        if worker_input.blocks_to_swap_out is not None:
            src_indices, dst_indices = worker_input.blocks_to_swap_out
            if src_indices.numel() > 0:
                # Swap from TPU to CPU, layer by layer.
                for tpu_kv, cpu_kv in zip(self.tpu_cache, self.cpu_cache):
                    tpu_k_cache, tpu_v_cache = tpu_kv
                    cpu_k_cache, cpu_v_cache = cpu_kv
                    cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
                    cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]

        if worker_input.blocks_to_copy is not None:
            src_indices, dst_indices = worker_input.blocks_to_copy
            if src_indices.numel() > 0:
                self.model_runner.attn_backend.copy_blocks(
                    self.tpu_cache, (src_indices, dst_indices))


def _make_src_to_dst(
    mapping: List[Tuple[int, int]],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Split `(src, dst)` block pairs into two int64 index tensors.

    Args:
        mapping: Pairs of `(source_block, destination_block)` numbers.
        src_device: Device on which the source index tensor is created.
        dst_device: Device on which the destination index tensor is created.

    Returns:
        `None` when `mapping` is empty (or `None`); otherwise the tuple
        `(src_indices, dst_indices)`, in the same order as `mapping`.
    """
    if not mapping:
        return None

    # Walk the pairs once, collecting each side into its own list.
    sources: List[int] = []
    destinations: List[int] = []
    for src_block, dst_block in mapping:
        sources.append(src_block)
        destinations.append(dst_block)

    src_tensor = torch.tensor(sources, dtype=torch.int64, device=src_device)
    dst_tensor = torch.tensor(destinations,
                              dtype=torch.int64,
                              device=dst_device)
    return src_tensor, dst_tensor


@torch.compile(backend="openxla")
def _insert_kv(
    k: torch.Tensor,
    v: torch.Tensor,
    indices: torch.Tensor,
    tpu_k_cache: torch.Tensor,
    tpu_v_cache: torch.Tensor,
) -> None:
    """Write `k` / `v` into the TPU key / value caches in place.

    Compiled with the "openxla" backend. The values are written at
    `indices` along dimension 1 of each cache tensor.

    Args:
        k: Key values to store at `tpu_k_cache[:, indices]`.
        v: Value values to store at `tpu_v_cache[:, indices]`.
        indices: Positions along dimension 1 that receive the new values.
        tpu_k_cache: Key cache, modified in place.
        tpu_v_cache: Value cache, modified in place.
    """
    # Mark both caches as buffer donors before the in-place writes.
    # NOTE(review): presumably this lets XLA reuse the cache buffers for the
    # update instead of allocating new ones; confirm against the torch_xla
    # documentation.
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
    tpu_k_cache[:, indices] = k
    tpu_v_cache[:, indices] = v